Merged
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_default_mm_loras.py
@@ -29,7 +29,7 @@ def multimodal_server(): # noqa: F811
"--dtype",
"half",
"--max-model-len",
"12800",
"4096",
"--enforce-eager",
# lora config below
"--enable-lora",
12 changes: 11 additions & 1 deletion vllm/config/model.py
@@ -2142,8 +2142,18 @@ def _get_and_verify_max_len(
     # If the user didn't specify `max_model_len`, then use that derived from
     # the model config as a default value.
     if max_model_len is None:
-        max_model_len = int(derived_max_model_len)
+        # For LongRoPE, default to original_max_position_embeddings to avoid
+        # performance degradation for shorter sequences
+        if rope_scaling is not None and rope_scaling["rope_type"] == "longrope":
+            max_model_len = int(
+                getattr(
+                    hf_config, "original_max_position_embeddings", derived_max_model_len
+                )
+            )
Comment on lines +2145 to +2152
Member

Why does this not belong above where derived_max_model_len is calculated?

It could come after the first rope_scaling check?

Contributor Author

We want to override the default max_model_len to original_max_position_embeddings while still allowing the user to manually specify a max_model_len which is larger than this value. I had originally written the code as modifying derived_max_model_len, but (as I read it) derived_max_model_len is really intended to be max_position_embeddings, because of the later check in line 2143:

https://github.com/vllm-project/vllm/pull/27431/files/684e475735f9594a784dd62aa14e0845884872db#diff-998c640befaf137b9af825f29f4e6e47d273caab1fd04093c97df24b18f5c417L2134

If max_model_len is manually set to a value exceeding derived_max_model_len, it will throw an error. Hopefully I'm understanding your question correctly?

Member

Thank you for explaining, I think I understand now!

+        else:
+            max_model_len = int(derived_max_model_len)
         max_model_len = current_platform.check_max_model_len(max_model_len)

     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     elif max_model_len > derived_max_model_len:
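For readers skimming the diff, the following standalone sketch illustrates the precedence the author describes in the thread above: with LongRoPE and no user-provided max_model_len, the default drops to original_max_position_embeddings, while an explicit user value may exceed that default but not derived_max_model_len. This is not vLLM code; the helper name and its plain arguments are hypothetical, and the real logic lives in `_get_and_verify_max_len` in vllm/config/model.py.

```python
from __future__ import annotations


def resolve_max_model_len(
    user_max_model_len: int | None,
    derived_max_model_len: int,  # roughly max_position_embeddings from the HF config
    original_max_position_embeddings: int | None,
    rope_type: str | None,
) -> int:
    """Hypothetical sketch of the max_model_len precedence in this PR."""
    if user_max_model_len is None:
        # No user override: LongRoPE models default to the original
        # (pre-extension) context length to avoid the short-sequence
        # performance cost of the long scaling factors.
        if rope_type == "longrope" and original_max_position_embeddings:
            return original_max_position_embeddings
        return derived_max_model_len
    # A user-specified value may exceed the LongRoPE default, but not the
    # derived (extended) maximum; vLLM rejects values above that.
    if user_max_model_len > derived_max_model_len:
        raise ValueError(
            f"max_model_len={user_max_model_len} exceeds the model's "
            f"maximum of {derived_max_model_len}"
        )
    return user_max_model_len


# Example with a LongRoPE model whose original context is 4096 tokens and
# whose extended maximum is 131072 tokens (illustrative numbers only).
assert resolve_max_model_len(None, 131072, 4096, "longrope") == 4096
assert resolve_max_model_len(32768, 131072, 4096, "longrope") == 32768
```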
@@ -5,8 +5,13 @@
 import torch
 import torch.nn as nn

+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
+
 from .common import rotate_neox
+
+logger = init_logger(__name__)


 class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
     """Phi3 family of models scaled rotary embedding.
@@ -43,6 +48,22 @@ def __init__(
         self.short_factor = short_factor
         self.long_factor = long_factor

+        # Force long factors if max_model_len (runtime max length) exceeds
+        # original_max_position_embeddings to prevent KV cache invalidation when
+        # sequences cross this threshold during generation
+        max_model_len = get_current_vllm_config().model_config.max_model_len
+        self.use_long_rope = max_model_len > original_max_position_embeddings
+        if self.use_long_rope:
+            logger.warning_once(
+                "Using LongRoPE scaling factors. This enables longer "
+                "contexts (%d tokens vs original %d tokens) at the cost of "
+                "some performance degradation for shorter sequences. If "
+                "this is not desired, set `max_model_len` to be at most %d.",
+                max_position_embeddings,
+                original_max_position_embeddings,
+                original_max_position_embeddings,
+            )
+
         scale = self.max_position_embeddings / self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
@@ -112,15 +133,12 @@ def forward(
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)

-        k = self.original_max_position_embeddings
-        long_prompt_offset = (
-            torch.any(positions > k).float() * torch.full_like(positions, k)
-        ).long()
-        idx = (
-            torch.add(positions, long_prompt_offset)
-            if long_prompt_offset is not None
-            else positions
-        )
+        if self.use_long_rope:
+            k = self.original_max_position_embeddings
+            long_prompt_offset = torch.full_like(positions, k).long()
+            idx = torch.add(positions, long_prompt_offset)
+        else:
+            idx = positions
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

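As a rough picture of why forcing the long factors matters (the cache layout below is assumed from how the diff indexes long_short_cos_sin_cache, not copied from vLLM): the combined cache holds short-factor entries for positions below original_max_position_embeddings followed by long-factor entries, so adding a constant offset of original_max_position_embeddings keeps every position of a long-context run in the long-factor half instead of switching halves once a position crosses the threshold, which would make earlier KV cache entries inconsistent with later ones.

```python
import torch

# Toy sizes; real models use original_max_position_embeddings (e.g. 4096)
# and a much larger extended maximum.
original_max = 4
extended_max = 8
rot_dim = 2

# Stand-in tables: zeros for short-factor entries, ones for long-factor
# entries, concatenated in the assumed long_short_cos_sin_cache layout.
short_cos_sin = torch.zeros(original_max, rot_dim)
long_cos_sin = torch.ones(extended_max, rot_dim)
long_short_cache = torch.cat([short_cos_sin, long_cos_sin], dim=0)

positions = torch.tensor([0, 1, 5])  # a sequence that crosses original_max
use_long_rope = True  # i.e. max_model_len > original_max_position_embeddings

if use_long_rope:
    # Constant offset selects the long-factor half for every position.
    idx = positions + original_max
else:
    # Short-context runs stay entirely in the short-factor half.
    idx = positions

cos_sin = torch.index_select(long_short_cache, 0, idx)
print(cos_sin)  # all ones: long factors applied consistently across positions
```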