105 commits
eb34c8d
add support methods to abstract
MatthewBonanni Oct 8, 2025
87edf38
remove is_attn_backend_supported
MatthewBonanni Oct 8, 2025
fc493ae
all backends are V1 now
MatthewBonanni Oct 8, 2025
9618979
use backend_to_class_str
MatthewBonanni Oct 8, 2025
8aeb461
add MLA backend support details
MatthewBonanni Oct 8, 2025
eb8426f
use backend_to_class_str
MatthewBonanni Oct 8, 2025
aba576c
add support details for standard attention backends
MatthewBonanni Oct 8, 2025
ff18a9a
update cuda logic
MatthewBonanni Oct 8, 2025
eaed800
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 8, 2025
9687c99
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 8, 2025
df49484
fix pre-commit
MatthewBonanni Oct 8, 2025
ff5ad7c
fix argument mismatch
MatthewBonanni Oct 8, 2025
712ae59
fix pre-commit
MatthewBonanni Oct 8, 2025
97e1a2c
use block size literals
MatthewBonanni Oct 8, 2025
8f86714
replace backend_name_to_enum with direct calls
MatthewBonanni Oct 9, 2025
50596d8
use DeviceCapability objects
MatthewBonanni Oct 9, 2025
03f6963
update max
MatthewBonanni Oct 9, 2025
3bee84e
Fix block size adjustment
MatthewBonanni Oct 9, 2025
a716f3a
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 10, 2025
15234bb
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 13, 2025
2433669
split priorities by capability, update flashinfer min capability
MatthewBonanni Oct 14, 2025
a3617d7
change to typing imports
MatthewBonanni Oct 15, 2025
81d1b7b
backends specify their required kv cache layout
MatthewBonanni Oct 15, 2025
adaf53b
flashinfer supports up to 12.1
MatthewBonanni Oct 15, 2025
d1f1362
is_mla is false in base class
MatthewBonanni Oct 15, 2025
abb8375
triton supports fp8
MatthewBonanni Oct 15, 2025
85d8719
use CacheDType
MatthewBonanni Oct 15, 2025
1ef0417
add todo
MatthewBonanni Oct 15, 2025
a2c902f
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 15, 2025
16f9373
is_quantized_kv_cache use CacheDType
MatthewBonanni Oct 15, 2025
8474a14
fix supports_sink
MatthewBonanni Oct 15, 2025
62e6290
fix priority list
MatthewBonanni Oct 15, 2025
22dd1b8
fix FA block sizes
MatthewBonanni Oct 16, 2025
4bf076d
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 16, 2025
121d442
fix import failure
MatthewBonanni Oct 16, 2025
963cc9f
fix import error
MatthewBonanni Oct 16, 2025
778cd98
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 20, 2025
de3f302
fix import error
MatthewBonanni Oct 20, 2025
bc10bee
fix import
MatthewBonanni Oct 20, 2025
05aab3e
fix type error
MatthewBonanni Oct 21, 2025
7936c47
add flashmla support test
MatthewBonanni Oct 21, 2025
4f0f955
clean up head size validation
MatthewBonanni Oct 21, 2025
feded36
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 21, 2025
d8b8043
use KVCacheLayoutType
MatthewBonanni Oct 21, 2025
a3ccbba
move selector layout change to same place as block size change
MatthewBonanni Oct 21, 2025
3285c2c
MLA only supports head size 576
MatthewBonanni Oct 21, 2025
6eab504
fix kv_cache_dtype support logic
MatthewBonanni Oct 21, 2025
5523dac
fix test
MatthewBonanni Oct 21, 2025
58fc888
skip FA MLA if test is run on hardware where it's not supported
MatthewBonanni Oct 21, 2025
17fd954
fix test
MatthewBonanni Oct 21, 2025
2b23712
fix pre-commit
MatthewBonanni Oct 21, 2025
68a63b7
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 22, 2025
fc1d3f3
fix head size
MatthewBonanni Oct 22, 2025
ecdef49
fix pre-commit
MatthewBonanni Oct 22, 2025
9008e56
flashinfer_mla only support blackwell (only uses TRTLLM kernels)
MatthewBonanni Oct 22, 2025
b756ceb
compute capability checks
MatthewBonanni Oct 22, 2025
afccece
remove reference to backend_name_to_enum
MatthewBonanni Oct 22, 2025
33cb1ef
fix default block size
MatthewBonanni Oct 22, 2025
3f5439e
improve logs
MatthewBonanni Oct 23, 2025
75fce85
fix block size support
MatthewBonanni Oct 23, 2025
ba51339
fix getting priority list
MatthewBonanni Oct 23, 2025
d49fbf9
remove redundant block size methods
MatthewBonanni Oct 24, 2025
dd31329
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 24, 2025
b18a193
fix import
MatthewBonanni Oct 24, 2025
0e0cb6d
raise error instead of implicitly changing backend
MatthewBonanni Oct 24, 2025
1a7b366
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 27, 2025
f147663
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 27, 2025
1eefe90
don't ignore block size
MatthewBonanni Oct 27, 2025
97bee04
move block_size update back to check_and_update_config
MatthewBonanni Oct 27, 2025
0812fac
fix import
MatthewBonanni Oct 27, 2025
ec39247
address missing case
MatthewBonanni Oct 27, 2025
e6497dd
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 28, 2025
860bfdb
fix flashmla_sparse support
MatthewBonanni Oct 28, 2025
df1cd64
fix hybrid models
MatthewBonanni Oct 28, 2025
758b3a5
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 29, 2025
01b43ff
return only mla or non-mla priorities
MatthewBonanni Oct 29, 2025
ee894ea
cleanup
MatthewBonanni Oct 29, 2025
842e89b
skip test on hopper
MatthewBonanni Oct 29, 2025
bd190e7
temp: apply fixes for test
MatthewBonanni Oct 29, 2025
5bf94f6
Revert "skip test on hopper"
MatthewBonanni Oct 29, 2025
7e34939
revert to old check_and_update_config block_size logic
MatthewBonanni Oct 29, 2025
3b1e92f
Revert "temp: apply fixes for test"
MatthewBonanni Oct 29, 2025
54dffe2
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 30, 2025
db6cc0f
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 30, 2025
d34eb77
add test_attention_selector to Blackwell Tests
MatthewBonanni Oct 31, 2025
48290ee
rename _Backend to AttentionBackendEnum, add class methods
MatthewBonanni Oct 31, 2025
1c71eab
get rid of get_min_compute_capability and get_max_compute_capability
MatthewBonanni Oct 31, 2025
6e9d1f1
fix pre-commit
MatthewBonanni Oct 31, 2025
d3cdda7
change methods to properties
MatthewBonanni Oct 31, 2025
925069c
device_capability not None
MatthewBonanni Oct 31, 2025
a0b56c5
query device_capability inside get_required_kv_cache_layout
MatthewBonanni Oct 31, 2025
fff453a
Update vllm/attention/backends/abstract.py
MatthewBonanni Oct 31, 2025
95aae78
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
530f356
class_path always None in decorator
MatthewBonanni Oct 31, 2025
933ee5f
type hint for value
MatthewBonanni Oct 31, 2025
255edc9
restore comment
MatthewBonanni Oct 31, 2025
6af36aa
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
c9d62f8
fix docs
MatthewBonanni Oct 31, 2025
93a0770
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
f6a5a32
add FLASHMLA_SPARSE to priority list
MatthewBonanni Oct 31, 2025
bc91050
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Oct 31, 2025
0435eca
fix test
MatthewBonanni Nov 2, 2025
a098d82
fix flashmla_sparse
MatthewBonanni Nov 2, 2025
d8215e0
Merge branch 'main' into backend_selection_refactor
MatthewBonanni Nov 2, 2025
4452f5f
fix pre-commit
MatthewBonanni Nov 2, 2025
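Taken together, the commits above describe a selector refactor: `_Backend` is renamed to `AttentionBackendEnum`, each backend advertises its own support details (compute-capability range, block sizes, KV-cache dtypes, required KV-cache layout, MLA or not), and the selector walks a per-capability priority list and raises instead of silently switching backends. The sketch below only illustrates that shape; the `BackendSupport` dataclass, the member values, the capability numbers, and the `select_backend` helper are all invented here for illustration and are not vLLM's actual API.

```python
from dataclasses import dataclass
from enum import Enum


@dataclass(frozen=True)
class BackendSupport:
    """Per-backend support metadata. All values below are placeholders."""

    min_capability: tuple[int, int]  # inclusive (major, minor)
    max_capability: tuple[int, int]  # inclusive (major, minor)
    supported_block_sizes: tuple[int, ...]  # empty tuple means "any"
    supported_kv_cache_dtypes: tuple[str, ...]
    required_kv_cache_layout: str | None  # e.g. "HND" / "NHD", or None
    is_mla: bool = False


class AttentionBackendEnum(Enum):
    """Hypothetical enum whose members carry their own support details."""

    TRITON_ATTN = BackendSupport((8, 0), (12, 1), (), ("auto", "fp8"), "NHD")
    FLASHINFER = BackendSupport((8, 0), (12, 1), (), ("auto", "fp8"), "HND")
    CUTLASS_MLA = BackendSupport((10, 0), (10, 9), (128,), ("auto",), None, is_mla=True)
    FLASHINFER_MLA = BackendSupport((10, 0), (10, 9), (32, 64), ("auto",), None, is_mla=True)

    def supports(
        self, capability: tuple[int, int], block_size: int, kv_cache_dtype: str
    ) -> bool:
        s = self.value
        return (
            s.min_capability <= capability <= s.max_capability
            and (not s.supported_block_sizes or block_size in s.supported_block_sizes)
            and kv_cache_dtype in s.supported_kv_cache_dtypes
        )


def select_backend(
    capability: tuple[int, int], block_size: int, kv_cache_dtype: str, use_mla: bool
) -> AttentionBackendEnum:
    """Return the first backend in (MLA or non-MLA) priority order that fits."""
    priorities = [b for b in AttentionBackendEnum if b.value.is_mla == use_mla]
    for backend in priorities:
        if backend.supports(capability, block_size, kv_cache_dtype):
            return backend
    raise ValueError("No supported attention backend for this configuration")


# e.g. a Blackwell-class GPU with block_size 128 and an unquantized cache:
assert select_backend((10, 0), 128, "auto", use_mla=True) is AttentionBackendEnum.CUTLASS_MLA
```

Keeping the metadata on the enum members is what lets the selector raise a clear error when a requested backend cannot serve the configuration, rather than changing it implicitly (see the "raise error instead of implicitly changing backend" commit).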
5 changes: 5 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -883,11 +883,16 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
- vllm/v1/attention/backends/mla/flashinfer_mla.py
- vllm/platforms/cuda.py
- vllm/attention/selector.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
# Attention
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
- pytest -v -s tests/kernels/attention/test_attention_selector.py
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
31 changes: 16 additions & 15 deletions tests/compile/test_fusion_attn.py
@@ -10,7 +10,7 @@
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
@@ -104,7 +104,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:

# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN:
if backend == AttentionBackendEnum.ROCM_ATTN:
# k/v as 1st dimention
# HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
@@ -116,7 +116,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
@@ -128,7 +128,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.TRITON_ATTN:
elif backend == AttentionBackendEnum.TRITON_ATTN:
# k/v as 2nd dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
@@ -140,7 +140,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.FLASHINFER:
elif backend == AttentionBackendEnum.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
@@ -244,8 +244,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = []
BACKENDS_FP4: list[_Backend] = []
BACKENDS_FP8: list[AttentionBackendEnum] = []
BACKENDS_FP4: list[AttentionBackendEnum] = []

if current_platform.is_cuda():
HEADS = [(64, 8), (40, 8)]
@@ -261,18 +261,18 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
TestAttentionNvfp4QuantPatternModel,
)
]
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER]
BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]

elif current_platform.is_rocm():
HEADS = [(32, 8), (40, 8)]
MODELS_FP8 = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
BACKENDS = [
_Backend.ROCM_AITER_UNIFIED_ATTN,
_Backend.ROCM_ATTN,
_Backend.TRITON_ATTN,
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.ROCM_ATTN,
AttentionBackendEnum.TRITON_ATTN,
]


@@ -302,18 +302,19 @@ def test_attention_quant_pattern(
custom_ops: str,
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
backend: AttentionBackendEnum,
dist_init,
):
"""Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and (
if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")

custom_ops_list = custom_ops.split(",") if custom_ops else []

device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.manual_seed(42)

vllm_config = VllmConfig(
@@ -402,7 +403,7 @@ def test_attention_quant_pattern(

result_fused_1 = model_compiled(q, k, v)

if backend == _Backend.FLASHINFER:
if backend == AttentionBackendEnum.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded
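The if/elif chain above allocates a differently shaped dummy KV cache per backend, depending on its layout ("HND" vs "NHD") and on whether the K/V pair dimension comes before or after num_blocks. Since backends now report their required KV-cache layout (per the "backends specify their required kv cache layout" commit), the same shapes could be derived from a small helper. The sketch below is illustrative only; the helper name and the kv_first flag are invented here.

```python
import torch


def dummy_kv_cache_shape(
    layout: str,  # "HND" or "NHD"
    kv_first: bool,  # True if the K/V pair dim precedes num_blocks
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    """Shape of a dummy paged KV cache, mirroring the test's if/elif chain."""
    if layout == "HND":
        per_block = (num_kv_heads, block_size, head_size)
    elif layout == "NHD":
        per_block = (block_size, num_kv_heads, head_size)
    else:
        raise ValueError(f"unknown KV cache layout: {layout}")
    return (2, num_blocks, *per_block) if kv_first else (num_blocks, 2, *per_block)


# e.g. the TRITON_ATTN case above: NHD layout, K/V pair as the 2nd dimension
kv_cache = torch.zeros(
    dummy_kv_cache_shape(
        "NHD", kv_first=False, num_blocks=4, block_size=16, num_kv_heads=8, head_size=128
    ),
    dtype=torch.float16,
)
assert kv_cache.shape == (4, 2, 16, 8, 128)
```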
24 changes: 12 additions & 12 deletions tests/compile/test_fusions_e2e.py
@@ -11,7 +11,7 @@
import pytest
import regex as re

from tests.v1.attention.utils import _Backend
from tests.v1.attention.utils import AttentionBackendEnum
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
@@ -24,7 +24,7 @@
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
backend: AttentionBackendEnum
attention_fusions: int
allreduce_fusions: int | None = None

@@ -39,14 +39,14 @@ class ModelBackendTestCase(NamedTuple):
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
@@ -56,7 +56,7 @@ class ModelBackendTestCase(NamedTuple):
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
@@ -67,7 +67,7 @@ class ModelBackendTestCase(NamedTuple):
ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
),
@@ -78,19 +78,19 @@ class ModelBackendTestCase(NamedTuple):
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
),
]
@@ -111,15 +111,15 @@ class ModelBackendTestCase(NamedTuple):
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
backend: AttentionBackendEnum,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == _Backend.FLASHINFER and (
if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
@@ -203,7 +203,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
backend: AttentionBackendEnum,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
6 changes: 3 additions & 3 deletions tests/config/test_multimodal_config.py
@@ -3,13 +3,13 @@

import pytest

from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MultiModalConfig


def test_mm_encoder_attn_backend_str_conversion():
config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN


def test_mm_encoder_attn_backend_invalid():
@@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
def test_mm_encoder_attn_backend_hash_updates():
base_hash = MultiModalConfig().compute_hash()
overridden_hash = MultiModalConfig(
mm_encoder_attn_backend=_Backend.FLASH_ATTN
mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
).compute_hash()
assert base_hash != overridden_hash
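These tests pin down three behaviours: a string is converted to the matching `AttentionBackendEnum` member, an invalid name is rejected, and overriding the backend changes the config hash. A minimal sketch of that behaviour with stand-in classes (not vLLM's real `MultiModalConfig` or enum):

```python
import hashlib
from enum import Enum


class AttentionBackendEnum(Enum):  # stand-in with just two members
    FLASH_ATTN = "FLASH_ATTN"
    TRITON_ATTN = "TRITON_ATTN"


def normalize_backend(
    value: str | AttentionBackendEnum | None,
) -> AttentionBackendEnum | None:
    """Accept an enum member or its name; reject anything else."""
    if value is None or isinstance(value, AttentionBackendEnum):
        return value
    try:
        return AttentionBackendEnum[value]
    except KeyError as e:
        raise ValueError(f"Invalid attention backend: {value!r}") from e


class MultiModalConfig:  # minimal stand-in, not the real config class
    def __init__(self, mm_encoder_attn_backend: str | AttentionBackendEnum | None = None):
        self.mm_encoder_attn_backend = normalize_backend(mm_encoder_attn_backend)

    def compute_hash(self) -> str:
        # The backend choice feeds the hash so that overriding it invalidates
        # compilation caches, which is what the hash-update test checks.
        name = self.mm_encoder_attn_backend.name if self.mm_encoder_attn_backend else ""
        return hashlib.sha256(name.encode()).hexdigest()


assert MultiModalConfig("FLASH_ATTN").mm_encoder_attn_backend is AttentionBackendEnum.FLASH_ATTN
assert MultiModalConfig().compute_hash() != MultiModalConfig("FLASH_ATTN").compute_hash()
```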
75 changes: 45 additions & 30 deletions tests/kernels/attention/test_attention_selector.py
@@ -127,12 +127,13 @@ def test_env(

elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla:
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# and Blackwell GPUs (SM 10.x), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only
# (SM 10.x), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
@@ -141,58 +142,72 @@
if block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip("CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
if capability[0] != 10:
pytest.skip("CUTLASS MLA is not supported on this platform")
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER_MLA":
if capability[0] != 10:
pytest.skip(
"FlashInfer MLA is not supported on this platform"
)
if block_size not in [32, 64]:
# FlashInfer MLA only supports block_size 32 or 64
pytest.skip(
"FlashInfer MLA only supports block_size 32 or 64"
)
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported,
)
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_dense_supported,
)

is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
backend = get_attn_backend(
576,
torch.float16,
None,
block_size,
use_mla=use_mla,
)
expected = name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
from vllm.attention.utils.fa_utils import (
flash_attn_supports_mla,
)

if not flash_attn_supports_mla():
pytest.skip(
"FlashAttention MLA not supported on this platform"
)
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = "TRITON_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER":
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
64, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER"
assert backend.get_name() == expected
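The comment at the top of this hunk summarizes the CUDA MLA selection rules that the rewritten test exercises with head size 576 and explicit capability checks. Below is a compressed, hedged restatement of those rules as a plain function; it is not the real selector, which additionally checks head size, dtype, kernel availability, and the per-capability priority list.

```python
def choose_cuda_mla_backend(capability: tuple[int, int], block_size: int) -> str:
    """Rough restatement of the MLA rules described in the test's comment."""
    is_blackwell = capability[0] == 10
    if is_blackwell and block_size == 128:
        return "CUTLASS_MLA"  # only with block_size == 128 on SM 10.x
    if is_blackwell and block_size in (32, 64):
        return "FLASHINFER_MLA"  # only on SM 10.x, block_size 32 or 64
    if block_size == 64:
        return "FLASHMLA"  # only with block_size == 64 (if the kernel is available)
    # FLASH_ATTN_MLA would be preferred where FlashAttention supports MLA;
    # TRITON_MLA is the generic fallback.
    return "TRITON_MLA"


assert choose_cuda_mla_backend((10, 0), 128) == "CUTLASS_MLA"
assert choose_cuda_mla_backend((9, 0), 16) == "TRITON_MLA"
```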