
Commit 732016a

karan authored and DarkLight1337 committed
Revert "Add batch invariant kernel override for FlashInfer backend [2/n]" (vllm-project#26220)
Signed-off-by: Karan Goel <3261985+karan@users.noreply.github.com>
1 parent c8db6bc commit 732016a

File tree

3 files changed: +29 -84 lines changed

tests/v1/generation/test_batch_invariance.py

Lines changed: 23 additions & 40 deletions
@@ -76,21 +76,18 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
       seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-    random.seed(seed)
+    random.seed(12345)

     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
-    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
-    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
-    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
+    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."

     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

     # Sampling parameters: longer outputs with a more random-sounding
@@ -114,7 +111,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=128,
+            max_num_seqs=1,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -129,7 +126,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=128,
+            max_num_seqs=batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -138,17 +135,15 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0

         for trial in range(num_trials):
-            # Create a batch of size `max_batch_size` and insert the needle at
+            # Create a batch of size `batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
-            batch_size = random.randint(max_batch_size // 2, max_batch_size)
             needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(
-                        _random_prompt(min_random_prompt, max_random_prompt))
+                    prompts.append(_random_prompt())

             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
@@ -159,19 +154,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text

             if text != baseline_text:
-                print(
-                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1

         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(f"[determinism] total={num_trials}, passed={passes}, "
-              f"failed={mismatches}, max_batch_size={max_batch_size}")
+              f"failed={mismatches}, batch_size={batch_size}")

         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (max_batch_size={max_batch_size}).")
+                f"of {num_trials} trials (batch_size={batch_size}).")

     finally:
         # Ensure engines are shutdown to free GPU/VRAM across test sessions
@@ -203,14 +196,9 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():

-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
-
-    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-    random.seed(seed)
+    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

@@ -224,15 +212,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     prompts = [
         "The capital of France is",
         "The capital of Germany is",
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
     ]

     sp = SamplingParams(
-        temperature=0.6,
+        temperature=0.0,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
@@ -251,25 +234,25 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
                         "enable logprobs return to run this test.")
         bs1_logprobs_per_prompt.append(step_logprobs)

-    # BS=N: run prompts in a batch and collect logprobs per step for each
+    # BS=2: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bsN_logprobs_per_prompt = []
+    bs2_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip("Logits are not available on RequestOutput; "
                         "enable logprobs return to run this test.")
-        bsN_logprobs_per_prompt.append(step_logprobs)
+        bs2_logprobs_per_prompt.append(step_logprobs)

-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
-    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
-            zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
-        assert len(logprobs_bs1) == len(logprobs_bsN), (
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
+    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
+            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
+        assert len(logprobs_bs1) == len(logprobs_bs2), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: "
                 f"{a.shape} vs {b.shape}")

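For readers skimming the revert, the restored test boils down to: generate the same prompts greedily at batch size 1 and batch size 2, then compare per-step logprobs exactly. The sketch below illustrates that flow outside the test harness; it is not the test itself. step_logprobs is a hypothetical stand-in for the file's _extract_step_logprobs helper, and the engine arguments are assumptions rather than the exact test configuration.

# Illustrative only: bitwise BS=1 vs BS=2 logprob comparison (requires CUDA).
import os

from vllm import LLM, SamplingParams

model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
prompts = ["The capital of France is", "The capital of Germany is"]

llm = LLM(model=model_name)
sp = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=8, logprobs=1)


def step_logprobs(request_output):
    # One scalar per decode step: the logprob of the token actually sampled.
    out = request_output.outputs[0]
    return [d[tok].logprob for d, tok in zip(out.logprobs, out.token_ids)]


# BS=1: one prompt per call.
bs1 = [step_logprobs(llm.generate([p], sp, use_tqdm=False)[0]) for p in prompts]
# BS=2: both prompts in a single batch.
bs2 = [step_logprobs(o) for o in llm.generate(prompts, sp, use_tqdm=False)]

for i, (a, b) in enumerate(zip(bs1, bs2)):
    assert a == b, f"prompt {i}: BS=1 vs BS=2 logprobs differ"
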
vllm/model_executor/layers/batch_invariant.py

Lines changed: 1 addition & 12 deletions
@@ -8,12 +8,8 @@

 import torch

-import vllm.envs as envs
-from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton

-logger = init_logger(__name__)
-

 def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
                             args: dict[str, Any]) -> dict[str, Any]:
@@ -561,12 +557,5 @@ def vllm_kernel_override_batch_invariant():
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
     if vllm_kernel_override_batch_invariant():
-        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
-        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
-        if curr_attn_backend not in supported_backends:
-            warning = "Forcibly updating attention backend to" \
-                      f" {supported_backends[0]} for batch_invariant. " \
-                      f" Supported backends: {supported_backends}."
-            logger.warning_once(warning)
-        os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
     enable_batch_invariant_mode()

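After this revert, init_batch_invariance no longer consults a list of supported backends or logs a warning: whenever the batch-invariant override is active it simply pins VLLM_ATTENTION_BACKEND to FLEX_ATTENTION and enables batch-invariant mode. A minimal usage sketch follows; the VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT environment variable is an assumption about how vllm_kernel_override_batch_invariant() is toggled, so check the module itself before relying on it.

# Sketch only, assuming the override is driven by an environment variable.
import os

os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"  # assumed toggle

from vllm.model_executor.layers.batch_invariant import init_batch_invariance

init_batch_invariance()

# Post-revert behavior: the backend is forced to FLEX_ATTENTION; FLASHINFER is
# no longer honored here even if it was set beforehand.
assert os.environ["VLLM_ATTENTION_BACKEND"] == "FLEX_ATTENTION"
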
vllm/v1/attention/backends/flashinfer.py

Lines changed: 5 additions & 32 deletions
@@ -20,8 +20,6 @@
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -44,7 +42,6 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
@@ -266,15 +263,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)

-        if vllm_kernel_override_batch_invariant():
-            self.decode_fixed_split_size = 2048
-            self.prefill_fixed_split_size = 4096
-            self.disable_split_kv = True
-        else:
-            self.decode_fixed_split_size = -1
-            self.prefill_fixed_split_size = -1
-            self.disable_split_kv = False
-
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                      self.kv_cache_spec.block_size)
@@ -368,12 +356,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],

     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_kernel_override_batch_invariant():
-                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
-            self._workspace_buffer = torch.zeros(buffer_size,
-                                                 dtype=torch.uint8,
-                                                 device=self.device)
+            self._workspace_buffer = torch.zeros(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.device)
         return self._workspace_buffer

     def _get_prefill_wrapper(self):
@@ -629,8 +615,6 @@ def build(self,
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
-                fixed_split_size=self.prefill_fixed_split_size,
-                disable_split_kv=self.disable_split_kv,
             )
         else:
             attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -684,8 +668,6 @@ def build(self,
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
-                fixed_split_size=self.decode_fixed_split_size,
-                disable_split_kv=self.disable_split_kv,
             )
         return attn_metadata

@@ -1066,8 +1048,6 @@ def fast_plan_decode(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     non_blocking: bool = True,
-    fixed_split_size: int = -1,
-    disable_split_kv: bool = False,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1105,10 +1085,6 @@ def fast_plan_decode(
             rope_scale,
             rope_theta,
             non_blocking,
-            None,  # block_tables
-            None,  # seq_lens
-            fixed_split_size,
-            disable_split_kv,
         )
         self.vllm_first_call = False
         return
@@ -1154,7 +1130,7 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

     try:
-        # Make sure we pass exactly 18 arguments for tensor core version
+        # Make sure we pass exactly 15 arguments for tensor core version
         self._plan_info = self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
@@ -1171,9 +1147,6 @@ def fast_plan_decode(
             head_dim,
             head_dim,
             False,  # causal
-            window_left,
-            fixed_split_size,
-            disable_split_kv,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e

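On the FlashInfer side, the revert removes the 2 GiB batch-invariant workspace and the fixed_split_size / disable_split_kv plumbing, returning to a single lazily allocated 256 MiB workspace buffer. The snippet below is a self-contained sketch of that lazy-allocation pattern; the holder class and attribute names are illustrative, not the backend's real ones.

# Illustrative lazy workspace allocation, mirroring the post-revert
# _get_workspace_buffer: one fixed-size uint8 buffer, created on first use.
import torch

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024  # 256 MiB


class WorkspaceHolder:
    def __init__(self, device: torch.device):
        self.device = device
        self._workspace_buffer = None

    def _get_workspace_buffer(self) -> torch.Tensor:
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
        return self._workspace_buffer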