@@ -104,9 +104,8 @@ def __init__(self,
104104 1 , self .total_num_key_value_heads // tp_size )
105105 self .head_dim = self .hidden_size // self .total_num_heads
106106 self .max_position_embeddings = config .max_position_embeddings
107- rope_pct = getattr (config , "rope_pct" ,
108- getattr (config , "partial_rotary_factor" , 1 ))
109- self .rotary_ndims = int (self .head_dim * rope_pct )
107+ self .partial_rotary_factor = getattr (
108+ config , "rope_pct" , getattr (config , "partial_rotary_factor" , 1 ))
110109 self .scaling = self .head_dim ** - 0.5
111110 self .q_size = self .num_heads * self .head_dim
112111 self .kv_size = self .num_key_value_heads * self .head_dim
@@ -130,9 +129,10 @@ def __init__(self,
130129 prefix = f"{ prefix } .o_proj" )
131130 self .rotary_emb = get_rope (
132131 self .head_dim ,
133- rotary_dim = self .rotary_ndims ,
132+ rotary_dim = self .head_dim ,
134133 max_position = self .config .max_position_embeddings ,
135134 base = self .config .rope_theta ,
135+ partial_rotary_factor = self .partial_rotary_factor ,
136136 )
137137 self .attn = Attention (self .num_heads ,
138138 self .head_dim ,
0 commit comments