diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 99cf0c77339a..34c136980234 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -220,15 +220,9 @@ def _compute_yarn_parameters(
     attention_factor = config.rope_scaling.get("attention_factor")
     mscale = config.rope_scaling.get("mscale")
     mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
-
-    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in config.rope_scaling:
-        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
-        factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )
 
     def get_mscale(scale, mscale=1):
         if scale <= 1:
@@ -496,6 +490,33 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
 
+    # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
+    # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
+    # However, for BC purposes, we allow the former to be unset.
+    original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
+    if original_max_position_embeddings is not None:
+        # Double-check: `factor` should match the ratio between the post-yarn and pre-yarn context lengths.
+        implicit_factor = config.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor:
+            logger.warning_once(
+                f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage; please correct the 'max_position_embeddings' fields in the model config."
+            )
+    # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
+    # pre-yarn or the post-yarn context length?
+    # BC: we assume it is the pre-yarn context length.
+    else:
+        logger.warning_once(
+            "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
+            "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
+            "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
+            "factor) -- we recommend updating both fields for optimal downstream model usage."
+        )
+
 
 def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
diff --git a/tests/utils/test_modeling_rope_utils.py b/tests/utils/test_modeling_rope_utils.py
index 9427b3771b4a..7e000e0ff1a1 100644
--- a/tests/utils/test_modeling_rope_utils.py
+++ b/tests/utils/test_modeling_rope_utils.py
@@ -77,6 +77,44 @@ def test_rope_validation(self):
             self.assertEqual(len(logs.output), 1)
             self.assertIn(model_specific_kwarg, logs.output[0])
 
+    def test_yarn_original_max_position_embeddings_validation(self):
+        """Tests that models with a missing or inconsistent `original_max_position_embeddings` raise a warning"""
+        config = LlamaConfig()
+
+        # good rope config: has a factor AND a matching original_max_position_embeddings -> no warnings
+        rope_config = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
+        }
+        config.rope_scaling = rope_config
+        with self.assertRaises(AssertionError):  # confirm that no warnings are thrown
+            with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+                rope_config_validation(config)
+
+        # bad rope config, no `original_max_position_embeddings` -> warning
+        rope_config = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+        }
+        config.rope_scaling = rope_config
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            rope_config_validation(config)
+        self.assertEqual(len(logs.output), 1)
+        self.assertIn("is unset", logs.output[0])
+
+        # bad rope config, mismatched implicit factor -> warning
+        rope_config = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": 1,
+        }
+        config.rope_scaling = rope_config
+        with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
+            rope_config_validation(config)
+        self.assertEqual(len(logs.output), 1)
+        self.assertIn("implicit factor", logs.output[0])
+
     def test_default_rope_numerically(self):
         # Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
         # multiple RoPE strategies will fail.
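
Reviewer note (not part of the patch): a minimal sketch of how the new validation path behaves when invoked directly, mirroring the added test. It assumes this patch is applied and only reuses objects already used in the test (`LlamaConfig`, `rope_config_validation`); the specific scaling values are illustrative.

    from transformers import LlamaConfig
    from transformers.modeling_rope_utils import rope_config_validation

    config = LlamaConfig()

    # Consistent config: factor == max_position_embeddings / original_max_position_embeddings -> no warning
    config.rope_scaling = {
        "rope_type": "yarn",
        "factor": 2.0,
        "original_max_position_embeddings": config.max_position_embeddings // 2,
    }
    rope_config_validation(config)

    # `original_max_position_embeddings` unset -> "is unset" warning (BC path, assumes
    # config.max_position_embeddings is the pre-yarn context length)
    config.rope_scaling = {"rope_type": "yarn", "factor": 2.0}
    rope_config_validation(config)

    # Mismatched implicit factor (max_position_embeddings / 1 != 2.0) -> "implicit factor" warning,
    # and the explicit factor of 2.0 is still the one used by YaRN
    config.rope_scaling = {
        "rope_type": "yarn",
        "factor": 2.0,
        "original_max_position_embeddings": 1,
    }
    rope_config_validation(config)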