|
@@ -50,6 +50,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.common import (
+    dispatch_rotary_emb_function)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -63,7 +65,7 @@
                                         BaseProcessingInfo, PromptReplacement,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.platforms import _Backend, current_platform
+from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
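The new import replaces the ad-hoc platform check deleted in the next hunk with a shared dispatcher. As a hedged sketch only (not the verbatim vLLM implementation, which may special-case additional platforms), `dispatch_rotary_emb_function` plausibly centralizes exactly the logic this diff removes from `apply_rotary_pos_emb_vision`, alongside a pure-PyTorch fallback like the `apply_rotary_emb_torch` named in the hunk context below:

```python
# Hedged sketch of the dispatcher the new import provides; reconstructed
# from the inlined check removed below, not copied from vLLM itself.
from typing import Callable

import torch

from vllm.platforms import current_platform


def apply_rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor,
                           sin: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch fallback. Assumed shapes: x is (seq, heads, head_dim),
    cos/sin are (seq, head_dim // 2)."""
    cos = cos.unsqueeze(-2)  # broadcast over the heads axis
    sin = sin.unsqueeze(-2)
    x1, x2 = x.chunk(2, dim=-1)
    # Non-interleaved (GPT-NeoX style) rotation of the two half-dims.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
    """Pick the fused flash-attn rotary kernel on CUDA, else the fallback."""
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
        return apply_rotary_emb
    return apply_rotary_emb_torch
```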
@@ -272,13 +274,11 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                 freqs: torch.Tensor) -> torch.Tensor:
+    rotary_emb_function = dispatch_rotary_emb_function()
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    apply_rotary_emb = apply_rotary_emb_torch
-    if current_platform.is_cuda():
-        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
-    output = apply_rotary_emb(t_, cos, sin).type_as(t)
+    output = rotary_emb_function(t_, cos, sin).type_as(t)
     return output
 
 
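For reference, a small usage sketch of the rewritten function. The tensor shapes here are assumptions for illustration; the real call sites in the model's vision attention layers supply the actual shapes:

```python
import torch

# Assumed illustrative shapes: 16 patches, 4 heads, head_dim 64.
seqlen, num_heads, head_dim = 16, 4, 64
t = torch.randn(seqlen, num_heads, head_dim, dtype=torch.float16)
freqs = torch.randn(seqlen, head_dim // 2)  # rotation angle per position/pair

out = apply_rotary_pos_emb_vision(t, freqs)
# The rotation is applied in float32 and cast back, so shape and dtype match.
assert out.shape == t.shape and out.dtype == t.dtype
```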
|
|