Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from functools import cache
from importlib.util import find_spec
from typing import Callable

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

logger = init_logger(__name__)


# common functions
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -65,6 +71,23 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)


@cache
def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
if current_platform.is_cuda():
return apply_rotary_emb

if current_platform.is_rocm():
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
return apply_rotary
else:
logger.warning(
"flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings.")

return apply_rotary_emb_torch


# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(num_rotations: int,
Expand Down
10 changes: 5 additions & 5 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
Expand All @@ -63,7 +65,7 @@
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Expand Down Expand Up @@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,

def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function()
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
output = rotary_emb_function(t_, cos, sin).type_as(t)
return output


Expand Down