diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 861965106774..4f02c996bda1 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -2,15 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
+from functools import cache
+from importlib.util import find_spec
+from typing import Callable
 
 import torch
 
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
+logger = init_logger(__name__)
+
 
 # common functions
 def rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
         return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
 
 
+@cache
+def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
+    if current_platform.is_cuda():
+        return apply_rotary_emb
+
+    if current_platform.is_rocm():
+        if find_spec("flash_attn") is not None:
+            from flash_attn.ops.triton.rotary import apply_rotary
+            return apply_rotary
+        else:
+            logger.warning(
+                "flash_attn is not installed. Falling back to PyTorch "
+                "implementation for rotary embeddings.")
+
+    return apply_rotary_emb_torch
+
+
 # yarn functions
 # Inverse dim formula to find dim based on number of rotations
 def yarn_find_correction_dim(num_rotations: int,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index f83a411459cc..38435a69444e 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -50,6 +50,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.common import (
+    dispatch_rotary_emb_function)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -63,7 +65,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
+    rotary_emb_function = dispatch_rotary_emb_function()
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    apply_rotary_emb = apply_rotary_emb_torch
-    if current_platform.is_cuda():
-        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
-    output = apply_rotary_emb(t_, cos, sin).type_as(t)
+    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output
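
For context, a minimal caller-side sketch of how the new `dispatch_rotary_emb_function` helper is meant to be used, mirroring the rewritten `apply_rotary_pos_emb_vision`. The function name `apply_vision_rope` and the comment about the selected backends are illustrative only and not part of the diff:

```python
import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    dispatch_rotary_emb_function)


def apply_vision_rope(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Resolved once per process (the helper is @cache-decorated): the
    # flash-attn rotary kernel on CUDA, the Triton rotary op on ROCm when
    # flash_attn is importable, otherwise the pure-PyTorch fallback.
    rotary_emb_function = dispatch_rotary_emb_function()
    cos, sin = freqs.cos(), freqs.sin()
    # Compute in float32 for stability, then cast back to the input dtype,
    # matching the pattern used in apply_rotary_pos_emb_vision.
    return rotary_emb_function(t.float(), cos, sin).type_as(t)
```

This keeps the platform/back-end selection in one cached place instead of repeating `current_platform.is_cuda()` checks at every call site.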