
Commit 6451294

[RoPE] explicit factor > implicit factor in YaRN (#40320)
explicit factor > implicit factor
1 parent 5a8ba87

File tree

2 files changed: +68 -9 lines

src/transformers/modeling_rope_utils.py

Lines changed: 30 additions & 9 deletions
@@ -220,15 +220,9 @@ def _compute_yarn_parameters(
     attention_factor = config.rope_scaling.get("attention_factor")
     mscale = config.rope_scaling.get("mscale")
     mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
-
-    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in config.rope_scaling:
-        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
-        factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )
 
     def get_mscale(scale, mscale=1):
         if scale <= 1:
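For illustration, the behaviour change in `_compute_yarn_parameters` can be summarised with a small standalone sketch. The helper name `resolve_yarn_factor` is hypothetical and the numbers are made up rather than taken from any released checkpoint; the point is that the explicit `factor` is now used as-is, where previously a present `original_max_position_embeddings` silently replaced it with the ratio of the two context lengths.

# Minimal sketch of the behaviour change, not the library implementation.
def resolve_yarn_factor(rope_scaling: dict, max_position_embeddings: int) -> float:
    factor = rope_scaling["factor"]  # after this commit: the explicit factor always wins
    original = rope_scaling.get("original_max_position_embeddings") or max_position_embeddings
    implicit_factor = max_position_embeddings / original
    if implicit_factor != factor:
        # Before this commit, a present `original_max_position_embeddings` silently
        # replaced `factor` with this ratio; now the mismatch only warrants a warning.
        print(f"explicit factor {factor} != implicit factor {implicit_factor}")
    return factor


# 163840 / 4096 == 40.0, so the explicit and implicit factors agree and nothing is printed.
resolve_yarn_factor(
    {"rope_type": "yarn", "factor": 40.0, "original_max_position_embeddings": 4096},
    max_position_embeddings=163840,
)

A DeepSeek-V3-style config that stores both fields therefore keeps working unchanged as long as the two values are consistent with each other.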
@@ -496,6 +490,33 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
 
+    # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
+    # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
+    # However, for BC purposes, we allow the former to be unset.
+    original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
+    if original_max_position_embeddings is not None:
+        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+        implicit_factor = config.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor:
+            logger.warning_once(
+                f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
+            )
+    # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
+    # pre-yarn or the post-yarn context length?
+    # BC: we assume it is the pre-yarn context length.
+    else:
+        logger.warning_once(
+            "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
+            "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
+            "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
+            "factor) -- we recommend updating both fields for optimal downstream model usage."
+        )
+
 
 def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
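As a usage sketch of the validation added above, a YaRN config whose explicit factor matches the implicit one passes `rope_config_validation` without warnings. `LlamaConfig` and `rope_config_validation` are the same objects exercised in the tests below; the specific context lengths (2048 pre-yarn, 8192 post-yarn) and the factor of 4.0 are illustrative values only.

from transformers import LlamaConfig
from transformers.modeling_rope_utils import rope_config_validation

# Post-yarn context (8192) equals pre-yarn context (2048) times the explicit factor
# (4.0), so the implicit and explicit factors agree and validation stays silent.
config = LlamaConfig(
    max_position_embeddings=8192,
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 2048,
    },
)
rope_config_validation(config)  # no warnings emitted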

tests/utils/test_modeling_rope_utils.py

Lines changed: 38 additions & 0 deletions
@@ -77,6 +77,44 @@ def test_rope_validation(self):
                 self.assertEqual(len(logs.output), 1)
                 self.assertIn(model_specific_kwarg, logs.output[0])
 
+    def test_yarn_original_original_max_position_embeddings_validation(self):
+        """Tests that models with no/bad `original_max_position_embeddings` raise a warning"""
+        config = LlamaConfig()
+
+        # good rope config: has a factor AND original_max_position_embeddings -> no warnings
+        rope_config = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
+        }
+        config.rope_scaling = rope_config
+        with self.assertRaises(AssertionError):  # confirm that no warnings are thrown
+            with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+                rope_config_validation(config)
+
+        # bad rope config, no `original_max_position_embeddings` -> warning
+        rope_config = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+        }
+        config.rope_scaling = rope_config
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            rope_config_validation(config)
+            self.assertEqual(len(logs.output), 1)
+            self.assertIn("is unset", logs.output[0])
+
+        # bad rope config, bad implicit fator -> warning
+        rope_config = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": 1,
+        }
+        config.rope_scaling = rope_config
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            rope_config_validation(config)
+            self.assertEqual(len(logs.output), 1)
+            self.assertIn("implicit factor", logs.output[0])
+
     def test_default_rope_numerically(self):
         # Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
         # multiple RoPE strategies will fail.
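Outside the test harness, the two warning paths covered by the new test can be reproduced with a rough sketch like the one below. The values are illustrative, and `warning_once` deduplicates, so each distinct message prints at most once per process.

from transformers import LlamaConfig
from transformers.modeling_rope_utils import rope_config_validation

config = LlamaConfig()  # max_position_embeddings defaults to 2048

# No `original_max_position_embeddings` -> the "is unset" warning added above.
config.rope_scaling = {"rope_type": "yarn", "factor": 2.0}
rope_config_validation(config)

# Inconsistent fields: implicit factor = 2048 / 1 = 2048.0 != 2.0 -> the mismatch warning.
config.rope_scaling = {"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 1}
rope_config_validation(config)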
