From 025485b4778911a0252b39d448481ba17309745b Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Fri, 16 Aug 2024 05:37:43 -0400 Subject: [PATCH] Use head_dim if in config for RoPE (#32495) * use head_dim if in config for RoPE * typo * simplify with getattr --- src/transformers/modeling_rope_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 839adaecd0ca6a..c09664d688c3b1 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -58,7 +58,8 @@ def _compute_default_rope_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -143,7 +144,8 @@ def _compute_dynamic_ntk_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -185,7 +187,8 @@ def _compute_yarn_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] @@ -265,7 +268,8 @@ def _compute_longrope_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] factor = config.rope_scaling.get("factor") @@ -450,7 +454,8 @@ def _validate_longrope_parameters(config: PretrainedConfig): _check_received_keys(rope_type, received_keys, required_keys, optional_keys) partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) short_factor = rope_scaling.get("short_factor") if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):