@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
-import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -11,6 +10,7 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
 
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+    # Batch-invariant matmuls are no longer needed after the cuBLAS overrides.
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # For PyTorch 2.9, B200 uses GEMV for bs=1.
+            # Requires https://github.com/pytorch/pytorch/pull/166735
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # The only source of batch non-invariance on Hopper is split-K,
+            # which can be disabled through the cuBLAS workspace config.
+            _original_cublas_workspace_cfg = os.environ.get(
+                "CUBLAS_WORKSPACE_CONFIG", None
+            )
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
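
Note on the Hopper branch above: ":16:8" pins cuBLAS to small fixed-size
workspaces, which the PyTorch determinism docs use to rule out split-K
kernels, and the prior values are saved first so they can presumably be
restored when batch-invariant mode is disabled. A minimal standalone sketch
of that save-then-set pattern (helper names are illustrative, not vLLM's API):

    import os

    _saved: dict[str, str | None] = {}

    def set_deterministic_cublas_workspace() -> None:
        # Remember the prior values so a later disable can restore them.
        for key in ("CUBLAS_WORKSPACE_CONFIG", "CUBLASLT_WORKSPACE_SIZE"):
            _saved[key] = os.environ.get(key)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # small fixed workspaces
        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"

    def restore_cublas_workspace() -> None:
        # Put back whatever was set before, deleting keys we introduced.
        for key, value in _saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
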
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
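
For context, the aten::bmm registration above goes through
torch.library.Library("aten", "IMPL"), which swaps a Python function in as
the CUDA implementation of an existing aten op; monkeypatching torch.bmm as
well presumably catches call sites that bypass dispatch. A self-contained
sketch of the mechanism (looped_bmm is a hypothetical stand-in for
bmm_batch_invariant, not vLLM's kernel):

    import torch

    def looped_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # One mm per batch element, so the override never re-enters aten::bmm.
        return torch.stack([torch.mm(a[i], b[i]) for i in range(a.shape[0])])

    lib = torch.library.Library("aten", "IMPL")
    lib.impl("aten::bmm", looped_bmm, "CUDA")

    x = torch.randn(4, 8, 16, device="cuda")
    y = torch.randn(4, 16, 32, device="cuda")
    out = torch.bmm(x, y)  # now routed through looped_bmm on CUDA
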
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
     )
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
-
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
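
On the @functools.cache removal above: with the cache, VLLM_BATCH_INVARIANT
was read once and then frozen for the life of the process; without it, each
call observes the environment's current value. A toy illustration of the
difference (helper names hypothetical):

    import functools
    import os

    @functools.cache
    def cached_flag() -> bool:
        return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"

    def live_flag() -> bool:
        return os.environ.get("VLLM_BATCH_INVARIANT", "0") == "1"

    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    assert cached_flag() is True   # first call populates the cache
    os.environ["VLLM_BATCH_INVARIANT"] = "0"
    assert cached_flag() is True   # stale: the cached result never updates
    assert live_flag() is False    # uncached: sees the new value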