
Commit faedbb4

[Feature] Extend batch invariant torch.compile to B200 (#27856)
Signed-off-by: PaulZhang12 <paulzhan@fb.com>
1 parent 40db194 commit faedbb4

File tree

3 files changed: +30 -17 lines changed


tests/v1/generation/test_batch_invariance.py

Lines changed: 0 additions & 2 deletions
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         model=model,
         max_num_seqs=1,
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enforce_eager=True,
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
-        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
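
Dropping enforce_eager=True from both fixtures means the batch-invariance tests now exercise vLLM's compiled execution path (torch.compile / CUDA graphs) instead of eager mode. A minimal sketch of the configuration the updated test now runs, assuming the public vllm.LLM API; the model name is a hypothetical placeholder, and VLLM_BATCH_INVARIANT / VLLM_TP_SIZE are the environment variables the suite already relies on:

import os

from vllm import LLM

os.environ["VLLM_BATCH_INVARIANT"] = "1"  # enable batch-invariant kernels

llm = LLM(
    model="facebook/opt-125m",  # hypothetical small model, illustration only
    max_num_seqs=1,
    tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
    gpu_memory_utilization=0.9,
    max_model_len=2048,
    dtype="bfloat16",
    # enforce_eager=True removed: the compiled path is now what gets tested
)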

vllm/model_executor/layers/batch_invariant.py

Lines changed: 24 additions & 15 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
-import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -11,6 +10,7 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
 
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+    # Batch invariant matmuls are no longer needed after cublas overrides
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # For PyTorch 2.9, B200 uses GEMV for bs=1
+            # Requires https://github.com/pytorch/pytorch/pull/166735
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # Only source of batch invariance for Hopper is split-k, can disable through
+            # cuBLAS workspace config
+            _original_cublas_workspace_cfg = os.environ.get(
+                "CUBLAS_WORKSPACE_CONFIG", None
+            )
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
     )
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
-
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
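
The registrations above rely on torch.library.Library("aten", "IMPL") to reroute CUDA dispatch of individual aten ops to batch-invariant kernels. Below is a minimal, self-contained sketch of that override mechanism only; det_mm is a hypothetical stand-in for mm_batch_invariant (a fixed-order broadcast-and-sum reduction, deterministic but slow), not the vLLM Triton kernel:

import torch

def det_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: a broadcast-and-sum reduction with one fixed
    # reduction order. It deliberately avoids calling torch.mm, which would
    # recurse now that aten::mm is rerouted to this function.
    return (a.unsqueeze(2) * b.unsqueeze(0)).sum(dim=1)

_lib = torch.library.Library("aten", "IMPL")  # open the aten namespace for overrides
_lib.impl("aten::mm", det_mm, "CUDA")         # CUDA mm calls now dispatch to det_mm

x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
y = torch.randn(8, 2, device="cuda", dtype=torch.bfloat16)
print(torch.mm(x, y).shape)  # torch.Size([4, 2]), computed by det_mm

Per the comments in the diff, PyTorch 2.10+ no longer needs these matmul overrides once the cuBLAS settings are applied; on 2.9, SM100 (B200) still needs them because bs=1 matmuls dispatch to GEMV, while Hopper only needs the CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE settings to disable split-K.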

vllm/utils/flashinfer.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,9 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
+    if vllm_is_batch_invariant():
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
+        return False
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
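
Combined with the removal of @functools.cache from vllm_is_batch_invariant(), the flag is re-read from the environment on every call, and FlashInfer's TRTLLM attention path is forced off whenever batch-invariant mode is active. A small usage sketch, assuming a CUDA-enabled vLLM build that includes this commit:

import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"  # turn on batch-invariant mode

from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.utils.flashinfer import force_use_trtllm_attention

assert vllm_is_batch_invariant()               # env var is consulted on each call now
assert force_use_trtllm_attention() is False   # TRTLLM attention disabled for determinism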
