diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py
index 5cc6fcfd9ac9..b864f9a31836 100644
--- a/tests/v1/generation/test_batch_invariance.py
+++ b/tests/v1/generation/test_batch_invariance.py
@@ -76,21 +76,18 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
       seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-    random.seed(seed)
+    random.seed(12345)
 
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
-    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
-    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
-    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
+    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
 
     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
     # Sampling parameters: longer outputs with a more random-sounding
@@ -114,7 +111,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=128,
+            max_num_seqs=1,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -129,7 +126,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=128,
+            max_num_seqs=batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -138,17 +135,15 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0
 
         for trial in range(num_trials):
-            # Create a batch of size `max_batch_size` and insert the needle at
+            # Create a batch of size `batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
-            batch_size = random.randint(max_batch_size // 2, max_batch_size)
             needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(
-                        _random_prompt(min_random_prompt, max_random_prompt))
+                    prompts.append(_random_prompt())
 
             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
@@ -159,19 +154,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text
 
             if text != baseline_text:
-                print(
-                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1
 
         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(f"[determinism] total={num_trials}, passed={passes}, "
-              f"failed={mismatches}, max_batch_size={max_batch_size}")
+              f"failed={mismatches}, batch_size={batch_size}")
 
         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (max_batch_size={max_batch_size}).")
+                f"of {num_trials} trials (batch_size={batch_size}).")
 
     finally:
         # Ensure engines are shutdown to free GPU/VRAM across test sessions
@@ -203,14 +196,9 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
 
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
-
-    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-    random.seed(seed)
+    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
 
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
@@ -224,15 +212,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     prompts = [
         "The capital of France is",
         "The capital of Germany is",
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
     ]
 
     sp = SamplingParams(
-        temperature=0.6,
+        temperature=0.0,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
@@ -251,25 +234,25 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
                         "enable logprobs return to run this test.")
         bs1_logprobs_per_prompt.append(step_logprobs)
 
-    # BS=N: run prompts in a batch and collect logprobs per step for each
+    # BS=2: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bsN_logprobs_per_prompt = []
+    bs2_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip("Logits are not available on RequestOutput; "
                         "enable logprobs return to run this test.")
-        bsN_logprobs_per_prompt.append(step_logprobs)
+        bs2_logprobs_per_prompt.append(step_logprobs)
 
-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
-    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
-            zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
-        assert len(logprobs_bs1) == len(logprobs_bsN), (
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
+    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
+            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
+        assert len(logprobs_bs1) == len(logprobs_bs2), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: "
                 f"{a.shape} vs {b.shape}")
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 150c48c0e880..c025d509d862 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -8,12 +8,8 @@
 
 import torch
 
-import vllm.envs as envs
-from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
 
-logger = init_logger(__name__)
-
 
 def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
                             args: dict[str, Any]) -> dict[str, Any]:
@@ -561,12 +557,5 @@ def vllm_kernel_override_batch_invariant():
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
     if vllm_kernel_override_batch_invariant():
-        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
-        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
-        if curr_attn_backend not in supported_backends:
-            warning = "Forcibly updating attention backend to" \
-                      f" {supported_backends[0]} for batch_invariant. " \
-                      f" Supported backends: {supported_backends}."
-            logger.warning_once(warning)
-            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
         enable_batch_invariant_mode()
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 13f18d103b53..15a252734d4d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -20,8 +20,6 @@
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -44,7 +42,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
@@ -266,15 +263,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
 
-        if vllm_kernel_override_batch_invariant():
-            self.decode_fixed_split_size = 2048
-            self.prefill_fixed_split_size = 4096
-            self.disable_split_kv = True
-        else:
-            self.decode_fixed_split_size = -1
-            self.prefill_fixed_split_size = -1
-            self.disable_split_kv = False
-
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                      self.kv_cache_spec.block_size)
@@ -368,12 +356,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_kernel_override_batch_invariant():
-                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
-            self._workspace_buffer = torch.zeros(buffer_size,
-                                                 dtype=torch.uint8,
-                                                 device=self.device)
+            self._workspace_buffer = torch.zeros(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.device)
         return self._workspace_buffer
 
     def _get_prefill_wrapper(self):
@@ -629,8 +615,6 @@ def build(self,
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
-                fixed_split_size=self.prefill_fixed_split_size,
-                disable_split_kv=self.disable_split_kv,
             )
         else:
             attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -684,8 +668,6 @@ def build(self,
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
-                fixed_split_size=self.decode_fixed_split_size,
-                disable_split_kv=self.disable_split_kv,
             )
 
         return attn_metadata
@@ -1066,8 +1048,6 @@ def fast_plan_decode(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     non_blocking: bool = True,
-    fixed_split_size: int = -1,
-    disable_split_kv: bool = False,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1105,10 +1085,6 @@ def fast_plan_decode(
             rope_scale,
             rope_theta,
             non_blocking,
-            None,  # block_tables
-            None,  # seq_lens
-            fixed_split_size,
-            disable_split_kv,
         )
         self.vllm_first_call = False
         return
@@ -1154,7 +1130,7 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
-        # Make sure we pass exactly 18 arguments for tensor core version
+        # Make sure we pass exactly 15 arguments for tensor core version
        self._plan_info = self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
@@ -1171,9 +1147,6 @@ def fast_plan_decode(
             head_dim,
             head_dim,
             False,  # causal
-            window_left,
-            fixed_split_size,
-            disable_split_kv,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e