diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py
index 336bda81a9ef..818ee2644b54 100644
--- a/tests/entrypoints/openai/test_default_mm_loras.py
+++ b/tests/entrypoints/openai/test_default_mm_loras.py
@@ -29,7 +29,7 @@ def multimodal_server():  # noqa: F811
         "--dtype",
         "half",
         "--max-model-len",
-        "12800",
+        "4096",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
diff --git a/vllm/config/model.py b/vllm/config/model.py
index c335c5c25e9e..e3d158c1c60f 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
     # If the user didn't specify `max_model_len`, then use that derived from
     # the model config as a default value.
     if max_model_len is None:
-        max_model_len = int(derived_max_model_len)
+        # For LongRoPE, default to original_max_position_embeddings to avoid
+        # performance degradation for shorter sequences
+        if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
+            max_model_len = int(
+                getattr(
+                    hf_config, "original_max_position_embeddings", derived_max_model_len
+                )
+            )
+        else:
+            max_model_len = int(derived_max_model_len)
         max_model_len = current_platform.check_max_model_len(max_model_len)
+
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     elif max_model_len > derived_max_model_len:
diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
index 2a42e3bd00ec..e58c9783479b 100644
--- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
@@ -5,8 +5,13 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
+
 from .common import rotate_neox
 
+logger = init_logger(__name__)
+
 
 class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     """Phi3 family of models scaled rotary embedding.
@@ -43,6 +48,22 @@ def __init__(
         self.short_factor = short_factor
         self.long_factor = long_factor
 
+        # Force long factors if max_model_len (runtime max length) exceeds
+        # original_max_position_embeddings to prevent KV cache invalidation when
+        # sequences cross this threshold during generation
+        max_model_len = get_current_vllm_config().model_config.max_model_len
+        self.use_long_rope = max_model_len > original_max_position_embeddings
+        if self.use_long_rope:
+            logger.warning_once(
+                "Using LongRoPE scaling factors. This enables longer "
+                "contexts (%d tokens vs original %d tokens) at the cost of "
+                "some performance degradation for shorter sequences. If "
+                "this is not desired, set `max_model_len` to be at most %d.",
+                max_position_embeddings,
+                original_max_position_embeddings,
+                original_max_position_embeddings,
+            )
+
         scale = self.max_position_embeddings / self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
@@ -112,15 +133,12 @@ def forward(
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
 
-        k = self.original_max_position_embeddings
-        long_prompt_offset = (
-            torch.any(positions > k).float() * torch.full_like(positions, k)
-        ).long()
-        idx = (
-            torch.add(positions, long_prompt_offset)
-            if long_prompt_offset is not None
-            else positions
-        )
+        if self.use_long_rope:
+            k = self.original_max_position_embeddings
+            long_prompt_offset = torch.full_like(positions, k).long()
+            idx = torch.add(positions, long_prompt_offset)
+        else:
+            idx = positions
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)