Commit cff41e7

bwasti authored and xuebwang-amd committed
[unrevert] Add batch invariant kernel override for FlashInfer backend [2/n] (vllm-project#26373)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@fb.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent e1c5d2d commit cff41e7
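This commit lets the batch-invariant kernel override run on top of the FlashInfer attention backend. A minimal usage sketch, assuming the override flag is the environment variable behind the vllm_kernel_override_batch_invariant() helper shown in the diff (the flag name below is an assumption, not confirmed by this commit); the backend name and engine arguments are taken from the test changes:

# Sketch only: enable the batch-invariant override together with FlashInfer.
# VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT is assumed to be the flag read by
# vllm_kernel_override_batch_invariant(); verify against batch_invariant.py.
import os

os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"  # assumed flag name
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"       # backend covered by this change

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-1.7B", enforce_eager=True, enable_prefix_caching=False)
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=8))
print(out[0].outputs[0].text)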

File tree

4 files changed: +81 -35 lines changed

csrc/moe/topk_softmax_kernels.cu

Lines changed: 1 addition & 3 deletions
@@ -21,7 +21,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"
 #include "../cub_helpers.h"
-#include "../core/batch_invariant.hpp"

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -406,8 +405,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
   using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
   static constexpr int VPT = Constants::VPT;
   static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
-  const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
-  const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
   const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

   dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
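With the batch-invariant special case (num_warps forced to 32) removed, the top-k gating softmax launcher always sizes its grid by plain ceiling division. A small Python sketch of that arithmetic, for illustration only:

# Illustration of the launch sizing that remains after this change.
def launch_dims(num_rows: int, rows_per_warp: int, warps_per_tb: int) -> tuple[int, int]:
    num_warps = (num_rows + rows_per_warp - 1) // rows_per_warp   # ceil(num_rows / rows_per_warp)
    num_blocks = (num_warps + warps_per_tb - 1) // warps_per_tb   # ceil(num_warps / warps_per_tb)
    return num_warps, num_blocks

# e.g. 13 rows, 4 rows per warp, 4 warps per thread block -> (4, 1)
print(launch_dims(13, 4, 4))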

tests/v1/generation/test_batch_invariance.py

Lines changed: 37 additions & 28 deletions
@@ -76,18 +76,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
       seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    random.seed(12345)
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)

     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
-    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
+    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
+    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
+    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."

     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

     # Sampling parameters: longer outputs with a more random-sounding
@@ -111,7 +114,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=1,
+            max_num_seqs=max_batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -126,7 +129,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=batch_size,
+            max_num_seqs=max_batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -135,15 +138,16 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0

         for trial in range(num_trials):
-            # Create a batch of size `batch_size` and insert the needle at
+            # Create a batch of size `max_batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
+            batch_size = random.randint(max_batch_size // 2, max_batch_size)
             needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(_random_prompt())
+                    prompts.append(_random_prompt(min_random_prompt, max_random_prompt))

             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
@@ -154,19 +158,20 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text

             if text != baseline_text:
+                print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1

         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(
             f"[determinism] total={num_trials}, passed={passes}, "
-            f"failed={mismatches}, batch_size={batch_size}"
+            f"failed={mismatches}, max_batch_size={max_batch_size}"
         )

         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (batch_size={batch_size})."
+                f"of {num_trials} trials (max_batch_size={max_batch_size})."
             )

     finally:
@@ -199,25 +204,28 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
-    # model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
+@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

     # Force float32 to avoid precision-induced differences.
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enforce_eager=True,  # helps reduce nondeterminism from some backends
+        enforce_eager=True,
+        enable_prefix_caching=False,
     )

-    prompts = [
-        "The capital of France is",
-        "The capital of Germany is",
-    ]
+    prompts = [_random_prompt(10, 1024) for i in range(100)]

     sp = SamplingParams(
-        temperature=0.0,
+        temperature=0.6,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
@@ -238,29 +246,29 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
         )
         bs1_logprobs_per_prompt.append(step_logprobs)

-    # BS=2: run prompts in a batch and collect logprobs per step for each
+    # BS=N: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bs2_logprobs_per_prompt = []
+    bsN_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip(
                 "Logits are not available on RequestOutput; "
                 "enable logprobs return to run this test."
             )
-        bs2_logprobs_per_prompt.append(step_logprobs)
+        bsN_logprobs_per_prompt.append(step_logprobs)

-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
-    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
-        zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
+        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
     ):
-        assert len(logprobs_bs1) == len(logprobs_bs2), (
+        assert len(logprobs_bs1) == len(logprobs_bsN), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)"
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
         )
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
             )
@@ -297,6 +305,7 @@ def LLM_with_max_seqs(
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
         enable_prefix_caching=False,
+        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
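The updated test passes explicit minimum and maximum lengths to _random_prompt. The helper itself is not part of this diff, so the sketch below is only a hypothetical stand-in that matches the new call signature; the repository's real implementation may build prompts differently.

import random

# Hypothetical stand-in for the _random_prompt(min_len, max_len) helper called
# above; assumes the bounds roughly control prompt length in words.
_WORDS = ["deterministic", "batch", "invariant", "kernel", "attention", "token"]

def _random_prompt(min_len: int, max_len: int) -> str:
    n = random.randint(min_len, max_len)
    return "Continue the story: " + " ".join(random.choice(_WORDS) for _ in range(n))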

vllm/model_executor/layers/batch_invariant.py

Lines changed: 14 additions & 1 deletion
@@ -8,8 +8,12 @@

 import torch

+import vllm.envs as envs
+from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton

+logger = init_logger(__name__)
+

 def _matmul_launch_metadata(
     grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
@@ -562,5 +566,14 @@ def vllm_kernel_override_batch_invariant():
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
     if vllm_kernel_override_batch_invariant():
-        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
+        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
+        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
+        if curr_attn_backend not in supported_backends:
+            warning = (
+                "Forcibly updating attention backend to"
+                f" {supported_backends[0]} for batch_invariant. "
+                f" Supported backends: {supported_backends}."
+            )
+            logger.warning_once(warning)
+            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
         enable_batch_invariant_mode()
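After this change, init_batch_invariance() keeps a user-selected FLASHINFER backend instead of unconditionally overwriting it with FLEX_ATTENTION; only unsupported (or unset) backends are replaced, with a one-time warning. A condensed, standalone sketch of that fallback behaviour, assuming the env-var-driven selection shown in the hunk:

import os

# Condensed illustration of the backend fallback in init_batch_invariance():
# supported backends pass through untouched; anything else is replaced by the
# first supported backend after a warning.
SUPPORTED = ["FLEX_ATTENTION", "FLASHINFER"]

def resolve_backend(requested: str | None) -> str:
    if requested in SUPPORTED:
        return requested
    print(f"Forcibly updating attention backend to {SUPPORTED[0]} for batch_invariant. "
          f"Supported backends: {SUPPORTED}.")
    return SUPPORTED[0]

os.environ["VLLM_ATTENTION_BACKEND"] = resolve_backend(os.getenv("VLLM_ATTENTION_BACKEND"))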

vllm/v1/attention/backends/flashinfer.py

Lines changed: 29 additions & 3 deletions
@@ -25,6 +25,9 @@
 )
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
@@ -50,6 +53,7 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
@@ -288,6 +292,15 @@ def __init__(
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)

+        if vllm_kernel_override_batch_invariant():
+            self.decode_fixed_split_size = 2048
+            self.prefill_fixed_split_size = 4096
+            self.disable_split_kv = True
+        else:
+            self.decode_fixed_split_size = -1
+            self.prefill_fixed_split_size = -1
+            self.disable_split_kv = False
+
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(
             self.model_config.max_model_len, self.kv_cache_spec.block_size
@@ -391,8 +404,11 @@ def __init__(

     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
+            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
+            if vllm_kernel_override_batch_invariant():
+                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
-                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device
+                buffer_size, dtype=torch.uint8, device=self.device
             )
         return self._workspace_buffer

@@ -669,6 +685,8 @@ def build(
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                fixed_split_size=self.prefill_fixed_split_size,
+                disable_split_kv=self.disable_split_kv,
             )
         else:
             attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -730,6 +748,8 @@
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                fixed_split_size=self.decode_fixed_split_size,
+                disable_split_kv=self.disable_split_kv,
             )
         return attn_metadata

@@ -1121,6 +1141,8 @@ def fast_plan_decode(
     rope_scale: float | None = None,
     rope_theta: float | None = None,
     non_blocking: bool = True,
+    fixed_split_size: int = -1,
+    disable_split_kv: bool = False,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1157,6 +1179,10 @@
             rope_scale,
             rope_theta,
             non_blocking,
+            None,  # block_tables
+            None,  # seq_lens
+            fixed_split_size,
+            disable_split_kv,
         )
         self.vllm_first_call = False
         return
@@ -1222,8 +1248,8 @@
             head_dim,
             False,  # causal
             window_left,
-            -1,  # fixed_split_size
-            False,  # disable_split_kv
+            fixed_split_size,
+            disable_split_kv,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e
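When the batch-invariant override is active, the FlashInfer metadata builder pins the planner to fixed KV split sizes (4096 for prefill, 2048 for decode), disables dynamic split-KV, and enlarges the workspace buffer from 256 MiB to 2 GiB; presumably this keeps the split/reduction schedule from varying with batch size. A condensed sketch of that selection logic, mirroring the hunks above (not the vLLM API itself):

# Illustrative only: the knobs this commit wires into the FlashInfer plan calls.
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024

def plan_overrides(batch_invariant: bool) -> dict:
    if batch_invariant:
        return {
            "prefill_fixed_split_size": 4096,
            "decode_fixed_split_size": 2048,
            "disable_split_kv": True,
            "workspace_bytes": FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT,
        }
    return {
        "prefill_fixed_split_size": -1,  # -1 leaves the split size up to FlashInfer
        "decode_fixed_split_size": -1,
        "disable_split_kv": False,
        "workspace_bytes": FLASHINFER_WORKSPACE_BUFFER_SIZE,
    }

print(plan_overrides(batch_invariant=True))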
