1010
1111import vllm .envs as envs
1212from vllm .logger import init_logger
13+ from vllm .platforms import current_platform
1314from vllm .triton_utils import tl , triton
1415from vllm .utils .torch_utils import is_torch_equal_or_newer
1516
@@ -736,11 +737,28 @@ def enable_batch_invariant_mode():
736737
737738 _batch_invariant_MODE = True
738739 _batch_invariant_LIB = torch .library .Library ("aten" , "IMPL" )
739- _batch_invariant_LIB .impl ("aten::mm" , mm_batch_invariant , "CUDA" )
740- _batch_invariant_LIB .impl ("aten::addmm" , addmm_batch_invariant , "CUDA" )
741- _batch_invariant_LIB .impl ("aten::matmul" , matmul_batch_invariant , "CUDA" )
742- _batch_invariant_LIB .impl ("aten::bmm" , bmm_batch_invariant , "CUDA" )
743- _batch_invariant_LIB .impl ("aten::linear" , linear_batch_invariant , "CUDA" )
740+
741+ # Batch invariant matmuls are no longer needed after cublas overrides
742+ if not is_torch_equal_or_newer ("2.10.0.dev" ):
743+ if current_platform .is_device_capability (100 ):
744+ # For PyTorch 2.9, B200 uses GEMV for bs=1
745+ # Requires https://github.com/pytorch/pytorch/pull/166735
746+ _batch_invariant_LIB .impl ("aten::mm" , mm_batch_invariant , "CUDA" )
747+ _batch_invariant_LIB .impl ("aten::addmm" , addmm_batch_invariant , "CUDA" )
748+ _batch_invariant_LIB .impl ("aten::matmul" , matmul_batch_invariant , "CUDA" )
749+ _batch_invariant_LIB .impl ("aten::linear" , linear_batch_invariant , "CUDA" )
750+ else :
751+ # Only source of batch variance for Hopper is split-k, which can be disabled
752+ # through the cuBLAS workspace config
753+ _original_cublas_workspace_cfg = os .environ .get (
754+ "CUBLAS_WORKSPACE_CONFIG" , None
755+ )
756+ _original_cublaslt_workspace_size = os .environ .get (
757+ "CUBLASLT_WORKSPACE_SIZE" , None
758+ )
759+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":16:8"
760+ os .environ ["CUBLASLT_WORKSPACE_SIZE" ] = "1"
761+
744762 _batch_invariant_LIB .impl (
745763 "aten::_log_softmax" , _log_softmax_batch_invariant , "CUDA"
746764 )
@@ -749,6 +767,7 @@ def enable_batch_invariant_mode():
749767 _batch_invariant_LIB .impl ("aten::mean.dim" , mean_batch_invariant , "CUDA" )
750768
751769 # Also monkeypatch torch.bmm directly as a fallback
770+ _batch_invariant_LIB .impl ("aten::bmm" , bmm_batch_invariant , "CUDA" )
752771 _original_torch_bmm = torch .bmm
753772 torch .bmm = bmm_batch_invariant
754773
@@ -770,14 +789,6 @@ def enable_batch_invariant_mode():
770789 )
771790 torch .backends .cuda .preferred_blas_library (backend = "cublaslt" )
772791
773- if not is_torch_equal_or_newer ("2.10.0.dev" ):
774- _original_cublas_workspace_cfg = os .environ .get ("CUBLAS_WORKSPACE_CONFIG" , None )
775- _original_cublaslt_workspace_size = os .environ .get (
776- "CUBLASLT_WORKSPACE_SIZE" , None
777- )
778- os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":16:8"
779- os .environ ["CUBLASLT_WORKSPACE_SIZE" ] = "1"
780-
781792
782793def disable_batch_invariant_mode ():
783794 global _batch_invariant_MODE , _batch_invariant_LIB , _original_torch_bmm
0 commit comments