
Commit 3cc0e92

Commit message: deepseek

1 parent 5045121 commit 3cc0e92

File tree: 25 files changed (+1013, -141 lines)

tests/v1/generation/test_batch_invariance.py

Lines changed: 214 additions & 75 deletions
@@ -12,45 +12,46 @@
 
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
-    # Lightweight random prompt generator to vary prompt lengths and content.
-    vocab = [
-        "alpha",
-        "bravo",
-        "charlie",
-        "delta",
-        "echo",
-        "foxtrot",
-        "golf",
-        "hotel",
-        "india",
-        "juliet",
-        "kilo",
-        "lima",
-        "mike",
-        "november",
-        "oscar",
-        "papa",
-        "quebec",
-        "romeo",
-        "sierra",
-        "tango",
-        "uniform",
-        "victor",
-        "whiskey",
-        "xray",
-        "yankee",
-        "zulu",
+    # Generate more realistic prompts that will actually produce varied tokens
+    # Use a mix of common English text patterns
+
+    prompt_templates = [
+        # Question-answer style
+        "Question: What is the capital of France?\nAnswer: The capital of France is",
+        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+
+        # Story/narrative style
+        "Once upon a time in a distant galaxy, there lived",
+        "The old man walked slowly down the street, remembering",
+        "In the year 2157, humanity finally discovered",
+
+        # Technical/code style
+        "To implement a binary search tree in Python, first we need to",
+        "The algorithm works by iterating through the array and",
+        "Here's how to optimize database queries using indexing:",
+
+        # Factual/informative style
+        "The Renaissance was a period in European history that",
+        "Climate change is caused by several factors including",
+        "The human brain contains approximately 86 billion neurons which",
+
+        # Conversational style
+        "I've been thinking about getting a new laptop because",
+        "Yesterday I went to the store and bought",
+        "My favorite thing about summer is definitely",
     ]
-    n = random.randint(min_words, max_words)
-    words = random.choices(vocab, k=n)
 
-    # Add some noise and punctuation variability
-    if random.random() < 0.5:
-        words[0] = words[0].capitalize()
-    if random.random() < 0.2:
-        words.append("".join(random.choices(string.ascii_lowercase, k=5)))
-    punct = random.choice([".", "?", "!", "...", ""])
-    return " ".join(words) + punct
+    # Pick a random template
+    base_prompt = random.choice(prompt_templates)
+
+    # Add some padding to vary the length if needed
+    if min_words > 50:
+        # For longer prompts, repeat context
+        padding_text = " This is an interesting topic that deserves more explanation. " * (min_words // 50)
+        base_prompt = base_prompt + padding_text
+
+    return base_prompt
 
 
 @pytest.mark.timeout(1000)
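
Aside: the reworked _random_prompt above keeps its signature but now draws from natural-language templates and, for larger min_words, appends a repeated padding sentence instead of sampling NATO-alphabet words. A minimal sketch of how the helper might be exercised on its own, assuming it is importable from this test module (the import path and printed checks are illustrative, not part of the commit):

import random

# Illustrative usage of the reworked helper; the import path is an assumption
# based on the file location tests/v1/generation/test_batch_invariance.py.
from tests.v1.generation.test_batch_invariance import _random_prompt

random.seed(0)  # only to make the template choice repeatable in this demo

short = _random_prompt(10, 50)     # small min_words: a bare template, no padding
long = _random_prompt(1024, 2048)  # large min_words: template + (1024 // 50) padding sentences

print(short)
print(len(long.split()), "words after padding")
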
@@ -91,7 +92,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     # Keep GPU memory usage low to avoid startup allocation failures.
     gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
     max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
-    swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
     # Sampling parameters: longer outputs with a more random-sounding
     # continuation,but still deterministic due to fixed seed.
@@ -117,7 +117,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         max_num_seqs=max_batch_size,
         gpu_memory_utilization=gpu_mem_util,
         max_model_len=max_model_len,
-        swap_space=swap_space_gb,
     )
 
     # Baseline generation for the needle prompt alone.
@@ -132,7 +131,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         max_num_seqs=max_batch_size,
         gpu_memory_utilization=gpu_mem_util,
         max_model_len=max_model_len,
-        swap_space=swap_space_gb,
     )
 
     mismatches = 0
@@ -195,16 +193,16 @@ def _extract_step_logprobs(request_output):
             ],
             dtype=torch.float32,
         )
-        return t
+        return t, inner.token_ids
 
-    return None
+    return None, None
 
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASH_ATTN", "FLASHINFER"])
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
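
Note on the hunk above: _extract_step_logprobs now returns a (logprobs_tensor, token_ids) pair, with (None, None) signalling that logprobs were not returned, and the backend parametrization gains FLASH_ATTN. A small self-contained sketch of the caller-side contract, using dummy values rather than a real RequestOutput (the one-logprob-per-token check is an assumption, not stated in the diff):

import torch

def handle_extracted(step_logprobs, token_ids):
    # Mirrors the caller pattern used in the test: a None tensor means
    # logprobs were unavailable (the test skips); otherwise both values
    # are collected for the BS=1 vs BS=N comparison.
    if step_logprobs is None:
        return "skip"
    # Assumption: one logprob entry per sampled token.
    assert step_logprobs.shape[0] == len(token_ids)
    return step_logprobs, token_ids

print(handle_extracted(torch.tensor([-0.11, -2.40, -0.03]), [1534, 287, 12]))
print(handle_extracted(None, None))  # -> "skip"
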
@@ -214,77 +212,226 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
-    # Force float32 to avoid precision-induced differences.
+    # For batch invariance, disable custom all-reduce to ensure deterministic
+    # all-reduce operations (custom all-reduce may not be deterministic)
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+    disable_custom_ar = vllm_kernel_override_batch_invariant()
+
+    if disable_custom_ar:
+        print(f"\n{'='*80}")
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
+        print(f"{'='*80}\n")
+
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enforce_eager=True,
         enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16", # not everything is supported
     )
 
-    prompts = [_random_prompt(10, 1024) for i in range(100)]
+    # Use more realistic prompts for better token generation
+    prompts = [_random_prompt(10, 50) for i in range(3)]
 
     sp = SamplingParams(
         temperature=0.6,
         top_p=1.0,
         max_tokens=8,
-        # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
         seed=1234,
         logprobs=5,
     )
 
     # BS=1: run prompts individually and collect logprobs per step.
+    print("\n" + "="*80)
+    print("STARTING BS=1 RUNS (each prompt individually)")
+    print("="*80 + "\n")
+
     bs1_logprobs_per_prompt = []
-    for p in prompts:
+    bs1_tokens_per_prompt = []
+    for idx, p in enumerate(prompts):
+        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
         outs = llm.generate([p], sp, use_tqdm=False)
         assert len(outs) == 1
-        step_logprobs = _extract_step_logprobs(outs[0])
+        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
         if step_logprobs is None:
             pytest.skip(
                 "Logits are not available on RequestOutput; "
                 "enable logprobs return to run this test."
             )
         bs1_logprobs_per_prompt.append(step_logprobs)
+        bs1_tokens_per_prompt.append(token_ids)
+        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
 
     # BS=N: run prompts in a batch and collect logprobs per step for each
     # prompt.
+    print("\n" + "="*80)
+    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
+    print("="*80 + "\n")
+
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
     bsN_logprobs_per_prompt = []
-    for o in outs_batched:
-        step_logprobs = _extract_step_logprobs(o)
+    bsN_tokens_per_prompt = []
+
+    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
+    for idx, o in enumerate(outs_batched):
+        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {o.outputs[0].token_ids if o.outputs else 'N/A'}")
+        step_logprobs, token_ids = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip(
                 "Logits are not available on RequestOutput; "
                 "enable logprobs return to run this test."
             )
         bsN_logprobs_per_prompt.append(step_logprobs)
+        bsN_tokens_per_prompt.append(token_ids)
 
     # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
-    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
-        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
+    failed_prompts = []
+    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
+        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt,
+            bs1_tokens_per_prompt, bsN_tokens_per_prompt)
     ):
-        assert len(logprobs_bs1) == len(logprobs_bsN), (
-            f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
-        )
+        if len(logprobs_bs1) != len(logprobs_bsN):
+            failed_prompts.append({
+                "prompt_idx": i,
+                "step": "all",
+                "reason": f"Different number of steps: {len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)",
+                "prompt_preview": prompts[i][:100],
+                "bs1_tokens": tokens_bs1,
+                "bsN_tokens": tokens_bsN,
+            })
+            continue
+
+        # Check if tokens match first
+        if tokens_bs1 != tokens_bsN:
+            failed_prompts.append({
+                "prompt_idx": i,
+                "step": "sampling",
+                "reason": "Different tokens sampled",
+                "prompt_preview": prompts[i][:100],
+                "bs1_tokens": tokens_bs1,
+                "bsN_tokens": tokens_bsN,
+                "bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
+                "bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
+            })
+            continue
+
         for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
-            assert a.shape == b.shape, (
-                f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
-            )
-            # Bitwise exact equality.
-            assert torch.equal(a, b), (
-                f"Bitwise logprobs mismatch at prompt {i}, step {t} "
-                f"(dtype={a.dtype}, shape={a.shape})."
-            )
+            if a.shape != b.shape:
+                failed_prompts.append({
+                    "prompt_idx": i,
+                    "step": t,
+                    "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                })
+                break
+
+            if not torch.equal(a, b):
+                max_diff = torch.abs(a - b).max().item()
+                # Print which token failed
+                print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
+                print(f" Token IDs: bs1={tokens_bs1[t] if t < len(tokens_bs1) else 'N/A'}, bsN={tokens_bsN[t] if t < len(tokens_bsN) else 'N/A'}")
+                print(f" BS=1 logprob: {a.tolist()}")
+                print(f" BS=N logprob: {b.tolist()}")
+                failed_prompts.append({
+                    "prompt_idx": i,
+                    "step": t,
+                    "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                    "bs1_all_logprobs": [logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))],
+                    "bsN_all_logprobs": [logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))],
+                })
+                break
+
+    # Print summary of all failures
+    if failed_prompts:
+        print(f"\n{'='*80}")
+        print(f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/{len(prompts)} prompts failed")
+        print(f"{'='*80}")
+        for fail in failed_prompts:
+            print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
+            print(f" Reason: {fail['reason']}")
+            print(f" Preview: {fail['prompt_preview']}...")
+
+            # Always show the tokens
+            if "bs1_tokens" in fail:
+                print(f" BS=1 tokens: {fail['bs1_tokens']}")
+            if "bsN_tokens" in fail:
+                print(f" BS=N tokens: {fail['bsN_tokens']}")
+
+            if "bs1_all_logprobs" in fail:
+                print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
+                for step_idx, logprobs in enumerate(fail['bs1_all_logprobs']):
+                    print(f" Step {step_idx}: {logprobs}")
+                print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
+                for step_idx, logprobs in enumerate(fail['bsN_all_logprobs']):
+                    print(f" Step {step_idx}: {logprobs}")
+        print(f"{'='*80}\n")
+
+        # Fail the test with summary
+        pytest.fail(
+            f"Batch invariance violated in {len(failed_prompts)}/{len(prompts)} prompts. "
+            f"See output above for details."
+        )
+
+
+def test_simple_generation():
+    """
+    Simple test that runs the model with a basic prompt and prints the output.
+    Useful for quick smoke testing and debugging.
+    """
+    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+
+    llm = LLM(
+        model=model,
+        max_num_seqs=1,
+        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enforce_eager=True,
+        gpu_memory_utilization=0.9,
+        max_model_len=2048,
+        dtype="bfloat16",
+        enable_prefix_caching=False,
+    )
+
+    prompt = "the capital of france is"
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=20,
+    )
+
+    print(f"\n{'='*80}")
+    print(f"Running simple generation test")
+    print(f"Prompt: '{prompt}'")
+    print(f"{'='*80}\n")
+
+    try:
+        outputs = llm.generate([prompt], sampling_params)
+
+        assert len(outputs) == 1
+        output_text = outputs[0].outputs[0].text
+
+        print(f"Output: '{output_text}'")
+        print(f"\n{'='*80}")
+        print(f"Full completion: '{prompt}{output_text}'")
+        print(f"{'='*80}\n")
+
+    finally:
+        with contextlib.suppress(Exception):
+            llm.shutdown()
 
 
 def LLM_with_max_seqs(
     model: str,
     max_num_seqs: int,
     gpu_memory_utilization: float,
     max_model_len: int,
-    swap_space: int,
 ) -> LLM:
     """
     Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@@ -293,17 +440,9 @@ def LLM_with_max_seqs(
     return LLM(
         model=model,
         max_num_seqs=max_num_seqs,
-        # Constrain GPU memory pool so test can run even on busy GPUs.
         gpu_memory_utilization=gpu_memory_utilization,
-        # Keep KV cache footprint small while allowing longer outputs.
         max_model_len=max_model_len,
-        # Allow some CPU offload if needed.
-        swap_space=swap_space,
-        # Keep things lean and CI-friendly.
-        dtype="float16",
-        # Single-GPU by default; override externally if desired.
+        dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
-        enforce_eager=True,
         enable_prefix_caching=False,
     )
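
The core of the new comparison logic in this commit is a failure-collecting bitwise check: rather than asserting on the first mismatch, every divergent prompt/step is recorded and reported together via pytest.fail. A condensed, self-contained sketch of that pattern with hypothetical tensors (in the test, these are per-step logprobs gathered from the BS=1 and BS=N runs):

import torch

# Hypothetical per-step logprobs standing in for the BS=1 and BS=N results.
bs1 = [torch.tensor([-0.10, -1.50]), torch.tensor([-0.20, -2.00000])]
bsN = [torch.tensor([-0.10, -1.50]), torch.tensor([-0.20, -2.00009])]

failed = []
for step, (a, b) in enumerate(zip(bs1, bsN)):
    if a.shape != b.shape:
        failed.append({"step": step, "reason": f"shape {a.shape} vs {b.shape}"})
        break
    if not torch.equal(a, b):  # bitwise equality, deliberately not allclose
        max_diff = torch.abs(a - b).max().item()
        failed.append({"step": step, "reason": f"bitwise mismatch (max_diff={max_diff:.6e})"})
        break

if failed:
    # The real test calls pytest.fail(...) with a summary; printing stands in here.
    print(f"batch invariance violated at {len(failed)} step(s): {failed}")
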
