Commit 355bda6

Fixup imports further, ignore 'frozen' imports in compilation, fix chunked prefill setting

Signed-off-by: Bram Wasti <bwasti@meta.com>

1 parent 2d330a7 commit 355bda6

File tree

7 files changed (+287, -20 lines)

tests/v1/generation/test_batch_invariance.py
Lines changed: 276 additions & 0 deletions

@@ -707,6 +707,282 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
         os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value


+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="Requires CUDA to match production inference path.",
+)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.forked
+def test_decode_logprobs_match_prefill_logprobs(backend):
+    """
+    Test that verifies decode logprobs match prefill logprobs.
+
+    For each decoded token at position i:
+    1. Run decode to generate N tokens and collect their logprobs
+    2. For each position i in [0, N):
+       - Take prefix = prompt + tokens[0:i]
+       - Run prefill(prefix + tokens[i]) to get logprob of tokens[i]
+       - Verify prefill logprob matches decode logprob bitwise
+
+    This ensures that the logprobs from decode are consistent with what
+    we would get if we ran prefill on each prefix.
+    """
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
+    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
+
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+
+    disable_custom_ar = vllm_kernel_override_batch_invariant()
+
+    if disable_custom_ar:
+        print(f"\n{'=' * 80}")
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
+        print(f"{'=' * 80}\n")
+
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=tp_size,
+        enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",
+    )
+
+    # Use a few test prompts
+    num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
+    prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]
+
+    # Generate longer sequences to test multiple decode steps
+    max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))
+
+    sp = SamplingParams(
+        temperature=0.0,  # Greedy for determinism
+        max_tokens=max_tokens,
+        logprobs=5,
+    )
+
+    print("\n" + "=" * 80)
+    print("STEP 1: Running decode to generate tokens and collect logprobs")
+    print("=" * 80 + "\n")
+
+    # Step 1: Run decode and collect logprobs
+    decode_outputs = llm.generate(prompts, sp, use_tqdm=False)
+
+    failed_comparisons = []
+
+    for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
+        print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")
+
+        # Extract decode logprobs and tokens
+        decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
+        if decode_logprobs is None:
+            pytest.skip(
+                "Logprobs are not available on RequestOutput; "
+                "enable logprobs return to run this test."
+            )
+
+        print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
+        print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")
+
+        # Step 2: For each token position, run prefill and compare
+        print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")
+
+        for token_idx in range(len(token_ids)):
+            # Construct the prefix up to (but not including) this token
+            current_token = token_ids[token_idx]
+
+            # We need to detokenize to get the text prefix
+            # For this, we'll use the tokenizer from the LLM
+            # However, the LLM API doesn't expose tokenizer easily, so we'll
+            # construct the prefix by decoding from the original prompt
+
+            # Get text up to this point by using the output text
+            # This is approximate but should work for verification
+            if token_idx == 0:
+                prefix_prompt = prompt
+            else:
+                # Use the partial output text up to this token
+                # We'll need to construct this from the full output
+                prefix_output = decode_output.outputs[0]
+                # Get the text for tokens 0 to token_idx-1
+                # Unfortunately, we don't have per-token text, so we'll use
+                # a different approach: run prefill with prompt + tokens[0:token_idx]
+
+                # Actually, we need to get the actual text. Let's use a workaround:
+                # Run a generation with max_tokens = token_idx to get that prefix
+                prefix_sp = SamplingParams(
+                    temperature=0.0,
+                    max_tokens=token_idx,
+                    logprobs=1,
+                )
+                prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
+                prefix_prompt = prompt + prefix_output.outputs[0].text
+
+            # Now run prefill with max_tokens=1 to get the logprob of the next token
+            prefill_sp = SamplingParams(
+                temperature=0.0,
+                max_tokens=1,
+                logprobs=5,
+            )
+
+            print(
+                f" [Token {token_idx}] Running prefill for prefix "
+                f"(len={len(prefix_prompt)})..."
+            )
+            prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
+                0
+            ]
+            prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)
+
+            if prefill_logprobs is None:
+                print(f" [Token {token_idx}] Warning: No prefill logprobs available")
+                continue
+
+            # The first token from prefill should match the current token
+            prefill_token = prefill_token_ids[0]
+            prefill_logprob = prefill_logprobs[0].item()
+            decode_logprob = decode_logprobs[token_idx].item()
+
+            print(
+                f" [Token {token_idx}] Decode token: {current_token}, "
+                f"logprob: {decode_logprob:.8f}"
+            )
+            print(
+                f" [Token {token_idx}] Prefill token: {prefill_token}, "
+                f"logprob: {prefill_logprob:.8f}"
+            )
+
+            # Check if tokens match
+            if current_token != prefill_token:
+                failed_comparisons.append(
+                    {
+                        "prompt_idx": prompt_idx,
+                        "token_idx": token_idx,
+                        "reason": "Token mismatch",
+                        "decode_token": current_token,
+                        "prefill_token": prefill_token,
+                        "decode_logprob": decode_logprob,
+                        "prefill_logprob": prefill_logprob,
+                        "prompt_text": prompt[:100],
+                        "prefix_text": prefix_prompt[:100],
+                    }
+                )
+                print(f" [Token {token_idx}] ✗ TOKEN MISMATCH!")
+                continue
+
+            # Check if logprobs match bitwise
+            if decode_logprob != prefill_logprob:
+                diff = abs(decode_logprob - prefill_logprob)
+                failed_comparisons.append(
+                    {
+                        "prompt_idx": prompt_idx,
+                        "token_idx": token_idx,
+                        "reason": "Logprob mismatch",
+                        "decode_token": current_token,
+                        "prefill_token": prefill_token,
+                        "decode_logprob": decode_logprob,
+                        "prefill_logprob": prefill_logprob,
+                        "diff": diff,
+                        "prompt_text": prompt[:100],
+                        "prefix_text": prefix_prompt[:100],
+                        "decode_all_tokens": token_ids,
+                        "decode_all_logprobs": decode_logprobs.tolist(),
+                    }
+                )
+                print(f" [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
+            else:
+                print(f" [Token {token_idx}] ✓ Match (bitwise equal)")
+
+    # Print summary
+    print(f"\n{'=' * 80}")
+    if failed_comparisons:
+        print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
+        print(f"{'=' * 80}")
+
+        # Group failures by prompt for better readability
+        failures_by_prompt: dict[int, list[dict]] = {}
+        for fail in failed_comparisons:
+            pid = fail["prompt_idx"]
+            if pid not in failures_by_prompt:
+                failures_by_prompt[pid] = []
+            failures_by_prompt[pid].append(fail)
+
+        for prompt_idx, failures in failures_by_prompt.items():
+            print(f"\n{'=' * 80}")
+            print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
+            print(f"{'=' * 80}")
+            print(f"Total failures for this prompt: {len(failures)}")
+
+            # Show where mismatches occur (which token positions)
+            mismatch_positions = [f["token_idx"] for f in failures]
+            print(f"Mismatch at token positions: {mismatch_positions}")
+
+            # Show first few failures in detail
+            for i, fail in enumerate(failures[:5]):  # Show first 5 failures per prompt
+                print(f"\n [Failure {i + 1}] Token position {fail['token_idx']}:")
+                print(f" Reason: {fail['reason']}")
+                print(f" Prefix text: '{fail['prefix_text']}...'")
+                print(
+                    f" Decode: token={fail['decode_token']}, "
+                    f"logprob={fail['decode_logprob']:.10f}"
+                )
+                print(
+                    f" Prefill: token={fail['prefill_token']}, "
+                    f"logprob={fail['prefill_logprob']:.10f}"
+                )
+                if "diff" in fail:
+                    print(f" Difference: {fail['diff']:.10e}")
+                    # Show in hex to see bitwise difference
+                    import struct
+
+                    decode_hex = struct.pack("f", fail["decode_logprob"]).hex()
+                    prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex()
+                    print(f" Decode logprob (hex): 0x{decode_hex}")
+                    print(f" Prefill logprob (hex): 0x{prefill_hex}")
+
+                # If we have all tokens/logprobs, show the context
+                if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
+                    token_idx = fail["token_idx"]
+                    all_tokens = fail["decode_all_tokens"]
+                    all_logprobs = fail["decode_all_logprobs"]
+
+                    # Show context: 2 tokens before and after
+                    start = max(0, token_idx - 2)
+                    end = min(len(all_tokens), token_idx + 3)
+
+                    print(f" Context (tokens {start} to {end - 1}):")
+                    for j in range(start, end):
+                        marker = " <-- MISMATCH" if j == token_idx else ""
+                        print(
+                            f" [{j}] token={all_tokens[j]}, "
+                            f"logprob={all_logprobs[j]:.8f}{marker}"
+                        )
+
+            if len(failures) > 5:
+                print(f"\n ... and {len(failures) - 5} more failures for this prompt")
+
+        print(f"\n{'=' * 80}\n")
+
+        pytest.fail(
+            f"Decode logprobs do not match prefill logprobs: "
+            f"{len(failed_comparisons)} mismatches found."
+        )
+    else:
+        print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
+        print(f"{'=' * 80}\n")
+
+
 def LLM_with_max_seqs(
     model: str,
     max_num_seqs: int,
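For reference, a minimal sketch of how the new test could be driven locally. The environment variable names and defaults come from the test above; the programmatic pytest invocation and the idea of running it as a standalone script are assumptions, not part of this commit.

    # Hypothetical runner for the new decode-vs-prefill check (not part of this
    # commit). Env var names/defaults mirror the test; everything else is assumed.
    import os
    import sys

    import pytest

    os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
    os.environ.setdefault("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    os.environ.setdefault("VLLM_TEST_TP_SIZE", "1")
    os.environ.setdefault("VLLM_TEST_SEED", "12345")
    os.environ.setdefault("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4")
    os.environ.setdefault("VLLM_DECODE_PREFILL_MAX_TOKENS", "16")

    sys.exit(
        pytest.main(
            [
                "tests/v1/generation/test_batch_invariance.py"
                "::test_decode_logprobs_match_prefill_logprobs",
                "-v",
                "-s",
            ]
        )
    )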

vllm/compilation/caching.py
Lines changed: 3 additions & 1 deletion

@@ -3,6 +3,7 @@

 import hashlib
 import inspect
+import os
 import pickle
 from unittest.mock import patch

@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
     )
     file_contents = {}
     for filepath in files:
-        if filepath == "<string>":
+        # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
+        if not os.path.isfile(filepath):
             file_contents[filepath] = ""
         else:
             with open(filepath) as f:
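A small illustration of why the broader check helps (a sketch, not part of the commit; the candidate paths below are illustrative examples): compiled-in or dynamically created modules report pseudo-filenames that do not exist on disk, and only the literal "<string>" was skipped before.

    # Sketch comparing the old and new skip conditions used in _compute_code_hash.
    import os

    candidates = ["<string>", "<frozen importlib._bootstrap>", __file__]
    for path in candidates:
        old_skip = path == "<string>"        # old check: literal match only
        new_skip = not os.path.isfile(path)  # new check: anything not on disk
        print(f"{path!r}: old_skip={old_skip}, new_skip={new_skip}")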

vllm/config/model.py
Lines changed: 3 additions & 4 deletions

@@ -20,6 +20,9 @@
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@@ -420,10 +423,6 @@ def __post_init__(
         video_pruning_rate: float | None,
     ) -> None:
         # Enable batch invariance settings if requested
-        from vllm.model_executor.layers.batch_invariant import (
-            vllm_kernel_override_batch_invariant,
-        )
-
         if vllm_kernel_override_batch_invariant():
             self.enforce_eager = True


vllm/config/parallel.py
Lines changed: 3 additions & 4 deletions

@@ -14,6 +14,9 @@
 import vllm.envs as envs
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, get_open_ports_list

@@ -560,10 +563,6 @@ def use_ray(self) -> bool:
     def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
-        from vllm.model_executor.layers.batch_invariant import (
-            vllm_kernel_override_batch_invariant,
-        )
-        from vllm.platforms import current_platform

         # Enable batch invariance settings if requested
         if vllm_kernel_override_batch_invariant():
vllm/config/scheduler.py
Lines changed: 0 additions & 8 deletions

@@ -170,20 +170,12 @@ def compute_hash(self) -> str:
         return hash_str

     def __post_init__(self, is_encoder_decoder: bool) -> None:
-        from vllm.model_executor.layers.batch_invariant import (
-            vllm_kernel_override_batch_invariant,
-        )
-
         if self.max_model_len is None:
             self.max_model_len = 8192

         if self.max_num_seqs is None:
             self.max_num_seqs = 128

-        # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
-            self.enable_chunked_prefill = False
-
         if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True

vllm/engine/arg_utils.py
Lines changed: 1 addition & 2 deletions

@@ -1698,8 +1698,7 @@ def _set_default_args(
         # for non-pooling tasks.
         # For pooling tasks the default is False
         if model_config.runner_type != "pooling":
-            if self.enable_chunked_prefill is None:
-                self.enable_chunked_prefill = True
+            self.enable_chunked_prefill = True

         # TODO: When prefix caching supports prompt embeds inputs, this
         # check can be removed.

vllm/model_executor/layers/batch_invariant.py
Lines changed: 1 addition & 1 deletion

@@ -755,9 +755,9 @@ def vllm_kernel_override_batch_invariant():
 def override_envs_for_invariance():
     curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
     supported_backends = [
+        "FLASH_ATTN",  # best supported backend
        "FLEX_ATTENTION",
         "FLASHINFER",
-        "FLASH_ATTN",
         "FLASH_ATTN_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
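The reorder suggests the list is preference-ordered, with FLASH_ATTN now the first choice. A hypothetical sketch of how such an ordering could be consumed is below; the body of override_envs_for_invariance is not shown in this hunk, so the fallback logic here is an assumption, not the actual implementation.

    # Hypothetical fallback logic (illustration only; not the real
    # override_envs_for_invariance implementation).
    import os

    SUPPORTED_BACKENDS = [
        "FLASH_ATTN",  # best supported backend
        "FLEX_ATTENTION",
        "FLASHINFER",
        "FLASH_ATTN_MLA",
        "TRITON_MLA",
    ]

    def pick_batch_invariant_backend() -> str:
        curr = os.environ.get("VLLM_ATTENTION_BACKEND")
        # Keep the configured backend if supported; otherwise fall back to the
        # first (preferred) entry in the list.
        return curr if curr in SUPPORTED_BACKENDS else SUPPORTED_BACKENDS[0]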
