39 changes: 30 additions & 9 deletions src/transformers/modeling_rope_utils.py
@@ -220,15 +220,9 @@ def _compute_yarn_parameters(
attention_factor = config.rope_scaling.get("attention_factor")
mscale = config.rope_scaling.get("mscale")
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")

# NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if "original_max_position_embeddings" in config.rope_scaling:
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
factor = config.max_position_embeddings / original_max_position_embeddings
Member Author:
here the implicit factor is taking precedence, which shouldn't happen
(validation is added below, in the validation function)

Member:
this deletes factor and it seems to be still used below with attention_factor 👀

Member Author (@gante, Aug 21, 2025):
@zucchini-nlp factor is defined a few lines above (L219). The line deleted here is a redefinition, where factor is implicitly derived from other parameters.

We should use explicit parameterization and warn when the defined parameters don't match as a whole (which is what this PR does 🤗)
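
As an aside (not part of the diff, which continues below), the mismatch discussed in this thread reduces to a single ratio check; a minimal sketch with made-up context lengths:

```python
# Illustrative sketch of the explicit vs. implicit YaRN factor discussed above
# (the numbers are invented for the example, not taken from a real checkpoint).
max_position_embeddings = 163840          # post-yarn context length
original_max_position_embeddings = 4096   # pre-yarn (pretrained) context length
factor = 40.0                             # explicitly configured YaRN scaling factor

# The implicit factor is the ratio between the two context lengths. The old code
# silently overwrote the explicit `factor` with this ratio; the PR keeps the
# explicit value and only warns when the two disagree.
implicit_factor = max_position_embeddings / original_max_position_embeddings
assert implicit_factor == factor  # consistent config -> no warning
```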

else:
original_max_position_embeddings = config.max_position_embeddings
original_max_position_embeddings = (
config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
)

def get_mscale(scale, mscale=1):
if scale <= 1:
@@ -496,6 +490,33 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)

# Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
# length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
# However, for BC purposes, we allow the former to be unset.
original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
if original_max_position_embeddings is not None:
# Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
implicit_factor = config.max_position_embeddings / original_max_position_embeddings
if implicit_factor != factor:
logger.warning_once(
f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
"the ratio implicitly set by other parameters (implicit factor = "
"post-yarn context length / pre-yarn context length = "
"config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
"behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
)
Comment on lines +500 to +507

Member:
should we throw an error instead, if users explicitly set config values to be mismatching?

Member Author (@gante, Aug 21, 2025):
Agreed in theory, but it can be breaking for some models on the Hub 😢 As such, I believe a warning is more adequate.
# No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
# pre-yarn or the post-yarn context length?
# BC: we assume it is the pre-yarn context length.
else:
logger.warning_once(
"config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
"**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
"config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
"factor) -- we recommend updating both fields for optimal downstream model usage."
)


def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
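
Between the two files, here is a minimal usage sketch (not part of the PR) of a YaRN config that satisfies the new validation; it follows the same `LlamaConfig` + `rope_config_validation` pattern as the tests below, with illustrative context lengths:

```python
from transformers import LlamaConfig
from transformers.modeling_rope_utils import rope_config_validation

# Post-yarn context length goes in `max_position_embeddings`...
config = LlamaConfig(max_position_embeddings=8192)
# ...while the pre-yarn (pretrained) context length and an explicit, matching
# `factor` go in `rope_scaling`, so the explicit and implicit factors agree.
config.rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 2048,  # 8192 / 4.0
}
rope_config_validation(config)  # no warning: factor == 8192 / 2048
```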
38 changes: 38 additions & 0 deletions tests/utils/test_modeling_rope_utils.py
@@ -77,6 +77,44 @@ def test_rope_validation(self):
self.assertEqual(len(logs.output), 1)
self.assertIn(model_specific_kwarg, logs.output[0])

def test_yarn_original_original_max_position_embeddings_validation(self):
"""Tests that models with no/bad `original_max_position_embeddings` raise a warning"""
config = LlamaConfig()

# good rope config: has a factor AND original_max_position_embeddings -> no warnings
rope_config = {
"rope_type": "yarn",
"factor": 2.0,
"original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
}
config.rope_scaling = rope_config
with self.assertRaises(AssertionError): # confirm that no warnings are thrown
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
rope_config_validation(config)

# bad rope config, no `original_max_position_embeddings` -> warning
rope_config = {
"rope_type": "yarn",
"factor": 2.0,
}
config.rope_scaling = rope_config
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
rope_config_validation(config)
self.assertEqual(len(logs.output), 1)
self.assertIn("is unset", logs.output[0])

# bad rope config, bad implicit factor -> warning
rope_config = {
"rope_type": "yarn",
"factor": 2.0,
"original_max_position_embeddings": 1,
}
config.rope_scaling = rope_config
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
rope_config_validation(config)
self.assertEqual(len(logs.output), 1)
self.assertIn("implicit factor", logs.output[0])

def test_default_rope_numerically(self):
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
# multiple RoPE strategies will fail.