Skip to content

Commit

Permalink
[Model] Rename Phi3 rope scaling type (vllm-project#5595)
Browse files Browse the repository at this point in the history
  • Loading branch information
garg-amit authored and jimpang committed Jul 24, 2024
1 parent 6f0ebde commit 18fecfd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
5 changes: 4 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,10 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] != "su" \
and rope_scaling["type"] != "longrope":
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
Expand Down
19 changes: 12 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


class Phi3SuScaledRotaryEmbedding(nn.Module):
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
Expand All @@ -491,11 +491,12 @@ def __init__(

if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)

self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
Expand Down Expand Up @@ -608,7 +609,9 @@ def get_rope(
is_neox_style, dtype)
else:
scaling_type = rope_scaling["type"]
if scaling_type != "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type != "su" and scaling_type != "longrope":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
Expand All @@ -633,7 +636,9 @@ def get_rope(
base, is_neox_style,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif scaling_type == "su" or scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
Expand All @@ -643,7 +648,7 @@ def get_rope(
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
Expand Down

0 comments on commit 18fecfd

Please sign in to comment.