diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py
index f05fac2478d8..8fd038bca5d0 100644
--- a/tests/v1/generation/test_batch_invariance.py
+++ b/tests/v1/generation/test_batch_invariance.py
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         model=model,
         max_num_seqs=1,
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enforce_eager=True,
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
-        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 0234f228d700..65babd10a948 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
-import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -11,6 +10,7 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
 
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+    # Batch invariant matmuls are no longer needed after cublas overrides
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # For PyTorch 2.9, B200 uses GEMV for bs=1
+            # Requires https://github.com/pytorch/pytorch/pull/166735
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # Only source of batch invariance for Hopper is split-k, can disable through
+            # cuBLAS workspace config
+            _original_cublas_workspace_cfg = os.environ.get(
+                "CUBLAS_WORKSPACE_CONFIG", None
+            )
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
     )
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
-
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index d7e4ea2e0388..0560fa15151c 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -19,6 +19,9 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
+    if vllm_is_batch_invariant():
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
+        return False
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 