From 4d93373d45da3a934486beaa1e2ee3f707e2ff71 Mon Sep 17 00:00:00 2001
From: evian
Date: Mon, 28 Apr 2025 00:29:48 +0800
Subject: [PATCH] [Minor][Models] Pass partial_rotary_factor parameter to rope

Signed-off-by: evian
---
 vllm/model_executor/models/llama.py     | 7 ++++---
 vllm/model_executor/models/persimmon.py | 3 ++-
 vllm/model_executor/models/stablelm.py  | 8 ++++----
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 04415dace4d9..38a18180e234 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -130,8 +130,8 @@ def __init__(self,
         self.head_dim = getattr(config, "head_dim",
                                 self.hidden_size // self.total_num_heads)
         # Phi models introduced a partial_rotary_factor parameter in the config
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
-        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
+                                             1)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -163,11 +163,12 @@ def __init__(self,
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.rotary_dim,
+            rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
 
         if hasattr(config, "interleaved_sliding_window"):
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 6e26658c9261..eacf02433b57 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -115,9 +115,10 @@ def __init__(self,
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=int(self.partial_rotary_factor * self.head_dim),
+            rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.scaling = self.head_dim**-0.5
         self.attn = Attention(self.num_heads,
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 13460d1dfd49..1cbda7267e4c 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -104,9 +104,8 @@ def __init__(self,
             1, self.total_num_key_value_heads // tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
-        rope_pct = getattr(config, "rope_pct",
-                           getattr(config, "partial_rotary_factor", 1))
-        self.rotary_ndims = int(self.head_dim * rope_pct)
+        self.partial_rotary_factor = getattr(
+            config, "rope_pct", getattr(config, "partial_rotary_factor", 1))
         self.scaling = self.head_dim**-0.5
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
@@ -130,9 +129,10 @@ def __init__(self,
                                         prefix=f"{prefix}.o_proj")
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.rotary_ndims,
+            rotary_dim=self.head_dim,
             max_position=self.config.max_position_embeddings,
             base=self.config.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
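
Note (not part of the patch): a minimal sketch of the behaviour this change relies on, assuming vLLM's get_rope helper in vllm.model_executor.layers.rotary_embedding accepts a partial_rotary_factor keyword and derives the effective rotary dimension internally, as the call sites above imply. The head size, base, max_position, and factor values are illustrative only.

    # Sketch only: the old call sites computed rotary_dim themselves,
    # the new ones pass the full head size plus partial_rotary_factor.
    from vllm.model_executor.layers.rotary_embedding import get_rope

    head_dim = 128
    partial_rotary_factor = 0.5  # e.g. Phi-style partial RoPE

    # Old pattern: each model pre-computed the truncated rotary dimension.
    rope_old = get_rope(
        head_dim,
        rotary_dim=int(partial_rotary_factor * head_dim),
        max_position=4096,
        base=10000,
    )

    # New pattern from this patch: hand the factor to get_rope and let it
    # derive the rotary dimension from head_dim internally.
    rope_new = get_rope(
        head_dim,
        rotary_dim=head_dim,
        max_position=4096,
        base=10000,
        partial_rotary_factor=partial_rotary_factor,
    )

    # Both constructions should end up rotating the same number of dims.
    assert rope_old.rotary_dim == rope_new.rotary_dim == 64

The design point is centralization: the int(partial_rotary_factor * head_dim) computation previously duplicated in llama.py, persimmon.py, and stablelm.py now lives in one place behind get_rope.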