2 changes: 1 addition & 1 deletion tests/compile/test_full_graph.py
@@ -11,8 +11,8 @@
import torch

from tests.quantization.utils import is_quant_method_supported
-from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
4 changes: 2 additions & 2 deletions tests/compile/test_fusion_attn.py
@@ -8,11 +8,11 @@

from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-create_common_attn_metadata)
+from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
+from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
3 changes: 2 additions & 1 deletion tests/kernels/attention/test_mha_attn.py
@@ -10,8 +10,9 @@
import pytest
import torch

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import MultiHeadAttention
-from vllm.attention.selector import _Backend, _cached_get_attn_backend
+from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
2 changes: 1 addition & 1 deletion tests/kernels/utils.py
@@ -15,10 +15,10 @@

from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
+from vllm.attention.backends.registry import _Backend
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
-from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

4 changes: 2 additions & 2 deletions tests/v1/attention/test_attention_backends.py
@@ -8,11 +8,11 @@
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

-from tests.v1.attention.utils import (BatchSpec, _Backend,
-create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
+from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
4 changes: 2 additions & 2 deletions tests/v1/attention/test_mla_backends.py
@@ -6,12 +6,12 @@
import pytest
import torch

-from tests.v1.attention.utils import (BatchSpec, _Backend,
-create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm import _custom_ops as ops
+from vllm.attention.backends.registry import _Backend
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
3 changes: 2 additions & 1 deletion tests/v1/attention/utils.py
@@ -8,10 +8,11 @@
import pytest
import torch

+from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
4 changes: 2 additions & 2 deletions tests/v1/spec_decode/test_eagle.py
@@ -8,10 +8,10 @@
import torch

from tests.utils import get_attn_backend_list_based_on_platform
-from tests.v1.attention.utils import (BatchSpec, _Backend,
-create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
+from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
4 changes: 2 additions & 2 deletions tests/v1/spec_decode/test_mtp.py
@@ -6,10 +6,10 @@
import pytest
import torch

-from tests.v1.attention.utils import (BatchSpec, _Backend,
-create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
+from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
3 changes: 2 additions & 1 deletion tests/v1/spec_decode/test_tree_attention.py
@@ -6,9 +6,10 @@

import torch

-from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
+from tests.v1.attention.utils import (create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
+from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

27 changes: 27 additions & 0 deletions vllm/attention/backends/registry.py
@@ -0,0 +1,27 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Attention backend registry"""
+
+import enum
+
+
+class _Backend(enum.Enum):
Inline review thread on the _Backend enum:

Collaborator: OK, I suppose this is where @ILikeIneine would like to plug in a different attn backend. I think we can add that later, although this "registry" is a very manual one: a registry usually exposes a way to add new entries, and here it is just a plain enum for now.

Contributor (author): Yes, I see what you're saying. We could replace the enum with some other object, but the ergonomics of the enum are quite nice. Maybe we could construct the enum at import time?

Collaborator: Checked out the other PR; I think it's fine as is, and we can add the actual registration mechanism in the next PR.

Collaborator: For a future PR: I think it might be nice to make this enum platform-specific.

+    FLASH_ATTN = enum.auto()
+    TRITON_ATTN = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    ROCM_AITER_MLA = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
+    TORCH_SDPA = enum.auto()
+    FLASHINFER = enum.auto()
+    FLASHINFER_MLA = enum.auto()
+    TRITON_MLA = enum.auto()
+    CUTLASS_MLA = enum.auto()
+    FLASHMLA = enum.auto()
+    FLASH_ATTN_MLA = enum.auto()
+    PALLAS = enum.auto()
+    IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
+    FLEX_ATTENTION = enum.auto()
+    TREE_ATTN = enum.auto()
+    ROCM_ATTN = enum.auto()
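
Following up on the review thread above: the enum keeps the nice call-site ergonomics (_Backend.FLASH_ATTN) and the actual registration mechanism is deferred to a later PR. As a rough illustration only, none of the helpers below exist in this change, a registration layer on top of the enum could look something like this:

# Hypothetical sketch, not part of this PR: one way to let out-of-tree
# code attach an attention backend implementation to the enum-based registry.
import enum


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    # ... remaining members as in registry.py above ...


# Maps an enum member to the qualified class name of its implementation.
_BACKEND_IMPLS: dict[_Backend, str] = {}


def register_backend(backend: _Backend, class_path: str) -> None:
    # Called by in-tree or plugin code to attach an implementation.
    _BACKEND_IMPLS[backend] = class_path


def resolve_backend(backend: _Backend) -> str:
    # Look up the implementation registered for this enum member.
    if backend not in _BACKEND_IMPLS:
        raise ValueError(f"no implementation registered for {backend.name}")
    return _BACKEND_IMPLS[backend]


# Example registration; the class path is illustrative only.
register_backend(_Backend.TORCH_SDPA, "my_pkg.attn.TorchSDPABackend")
print(resolve_backend(_Backend.TORCH_SDPA))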
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -10,6 +10,7 @@
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
@@ -26,7 +27,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op

logger = init_logger(__name__)
3 changes: 2 additions & 1 deletion vllm/attention/selector.py
@@ -11,8 +11,9 @@

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.registry import _Backend
from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

logger = init_logger(__name__)
@@ -20,6 +20,7 @@
import zmq

from vllm import envs
+from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -32,7 +33,7 @@
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
5 changes: 3 additions & 2 deletions vllm/envs.py
@@ -618,8 +618,9 @@ def get_vllm_port() -> Optional[int]:
# All possible options loaded dynamically from _Backend enum
"VLLM_ATTENTION_BACKEND":
env_with_choices("VLLM_ATTENTION_BACKEND", None,
-lambda: list(__import__('vllm.platforms.interface', \
-fromlist=['_Backend'])._Backend.__members__.keys())),
+lambda: list(__import__(
+'vllm.attention.backends.registry',
+fromlist=['_Backend'])._Backend.__members__.keys())),

# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER":
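
One note on the hunk above: the valid choices for VLLM_ATTENTION_BACKEND are produced by a lambda, so vllm.attention.backends.registry is only imported when the value is actually read and validated, not while vllm/envs.py itself is being imported, which helps avoid an import cycle. A standalone sketch of that deferred-choices pattern follows; env_with_choices here is an assumption and may differ from the real helper in vllm/envs.py.

# Standalone sketch of the deferred-choices pattern; this env_with_choices
# is an assumed stand-in, not the actual vLLM implementation.
import os
from typing import Callable, Optional, Union


def env_with_choices(
        name: str, default: Optional[str],
        choices: Union[list[str], Callable[[], list[str]]]
) -> Callable[[], Optional[str]]:
    def _read() -> Optional[str]:
        value = os.getenv(name, default)
        if value is None:
            return None
        # The callable (and any import inside it) runs only now.
        valid = choices() if callable(choices) else choices
        if value not in valid:
            raise ValueError(f"{name}={value!r} not in {valid}")
        return value

    return _read


# Importing the registry is deferred until the getter is first called.
get_attention_backend = env_with_choices(
    "VLLM_ATTENTION_BACKEND", None,
    lambda: list(__import__("vllm.attention.backends.registry",
                            fromlist=["_Backend"])._Backend.__members__.keys()))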
2 changes: 1 addition & 1 deletion vllm/model_executor/models/dots_ocr.py
@@ -9,6 +9,7 @@
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils
@@ -38,7 +39,6 @@
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
-from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)
3 changes: 2 additions & 1 deletion vllm/model_executor/models/ernie45_vl.py
@@ -34,6 +34,7 @@
from einops import rearrange, repeat
from transformers import BatchFeature

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
@@ -54,7 +55,7 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

2 changes: 1 addition & 1 deletion vllm/model_executor/models/glm4_1v.py
@@ -46,6 +46,7 @@
Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
@@ -69,7 +70,6 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

2 changes: 1 addition & 1 deletion vllm/model_executor/models/keye.py
@@ -17,6 +17,7 @@
BaseModelOutputWithPooling)
from transformers.utils import torch_int

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -39,7 +40,6 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_5_vl.py
@@ -38,6 +38,7 @@
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
@@ -62,7 +63,6 @@
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
-from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -41,6 +41,7 @@
from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
Qwen2VLVideoProcessor)

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
@@ -65,7 +66,6 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen3_vl.py
@@ -43,6 +43,7 @@
smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -66,7 +67,6 @@
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

2 changes: 1 addition & 1 deletion vllm/model_executor/models/siglip2navit.py
@@ -13,6 +13,7 @@
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig

+from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -22,7 +23,6 @@
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.platforms import _Backend

from .vision import get_vit_attn_backend

3 changes: 2 additions & 1 deletion vllm/model_executor/models/vision.py
@@ -10,11 +10,12 @@
import torch
from transformers import PretrainedConfig

+from vllm.attention.backends.registry import _Backend
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import current_platform

logger = init_logger(__name__)

1 change: 0 additions & 1 deletion vllm/platforms/__init__.py
@@ -9,7 +9,6 @@
from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname, supports_xccl

-from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum

logger = logging.getLogger(__name__)
7 changes: 5 additions & 2 deletions vllm/platforms/cpu.py
@@ -15,13 +15,15 @@
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

-from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
+from .interface import CpuArchEnum, Platform, PlatformEnum

logger = init_logger(__name__)

if TYPE_CHECKING:
+from vllm.attention.backends.registry import _Backend
from vllm.config import VllmConfig
else:
+_Backend = None
VllmConfig = None


@@ -90,10 +92,11 @@ def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"

@classmethod
-def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool, use_sparse: bool) -> str:
+from vllm.attention.backends.registry import _Backend
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
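
The cpu.py hunks above use the usual recipe for keeping a type annotation without creating a runtime import cycle: the import lives under TYPE_CHECKING, the annotation is quoted, and the module is re-imported lazily inside the method that actually needs the value. A minimal, generic sketch of that pattern follows, using the standard-library fractions module as a stand-in for vllm.attention.backends.registry:

# Generic sketch of the TYPE_CHECKING + deferred-import pattern used above.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so importing here cannot create an import cycle.
    from fractions import Fraction
else:
    Fraction = None  # keep the name bound in case something touches it


def halve(value: "Fraction") -> "Fraction":
    # The real import is deferred until the function is called.
    from fractions import Fraction
    return value / Fraction(2)


if __name__ == "__main__":
    from fractions import Fraction
    print(halve(Fraction(3, 4)))  # prints 3/8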