@@ -220,15 +220,9 @@ def _compute_yarn_parameters(
     attention_factor = config.rope_scaling.get("attention_factor")
     mscale = config.rope_scaling.get("mscale")
     mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
-
-    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have an
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in config.rope_scaling:
-        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
-        factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )
 
     def get_mscale(scale, mscale=1):
         if scale <= 1:
@@ -496,6 +490,33 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
 
+    # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
+    # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
+    # However, for BC purposes, we allow the former to be unset.
+    original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
+    if original_max_position_embeddings is not None:
+        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+        implicit_factor = config.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor:
+            logger.warning_once(
+                f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
+            )
+    # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
+    # pre-yarn or the post-yarn context length?
+    # BC: we assume it is the pre-yarn context length.
+    else:
+        logger.warning_once(
+            "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
+            "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
+            "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
+            "factor) -- we recommend updating both fields for optimal downstream model usage."
+        )
+
 
 def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
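The added validation only warns; the explicit `factor` always wins. A standalone sketch of that check, assuming a hypothetical `check_yarn_context_lengths` helper and plain `warnings.warn` in place of `logger.warning_once`:

```python
import warnings

def check_yarn_context_lengths(max_position_embeddings, rope_scaling):
    factor = rope_scaling["factor"]
    original = rope_scaling.get("original_max_position_embeddings")
    if original is not None:
        # Consistency check: the explicit factor should equal post-yarn / pre-yarn length.
        implicit_factor = max_position_embeddings / original
        if implicit_factor != factor:
            warnings.warn(
                f"Explicit factor {factor} != implicit factor {implicit_factor} "
                "(max_position_embeddings / original_max_position_embeddings); the explicit factor is used."
            )
    else:
        # BC path: without the field, the pre-yarn length is assumed to be max_position_embeddings.
        warnings.warn(
            "original_max_position_embeddings is unset; assuming max_position_embeddings is the "
            "pre-yarn context length."
        )

# Mismatched config: the implicit ratio is 8, but the explicit factor says 40 -> warns.
check_yarn_context_lengths(32768, {"factor": 40, "original_max_position_embeddings": 4096})
```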