[RoPE] explicit factor > implicit factor in YaRN #40320
```diff
@@ -220,15 +220,9 @@ def _compute_yarn_parameters(
     attention_factor = config.rope_scaling.get("attention_factor")
     mscale = config.rope_scaling.get("mscale")
     mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
-
-    # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
-    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
-    # values to compute the default attention scaling factor, instead of using `factor`.
-    if "original_max_position_embeddings" in config.rope_scaling:
-        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
-        factor = config.max_position_embeddings / original_max_position_embeddings
-    else:
-        original_max_position_embeddings = config.max_position_embeddings
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )

     def get_mscale(scale, mscale=1):
         if scale <= 1:
```

Review comment (on the removed block): this deletes

Reply: @zucchini-nlp We should use explicit parameterization and warn when the defined parameters don't match as a whole (which is what this PR does 🤗)
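To illustrate the effect of this hunk, here is a minimal standalone sketch (made-up values and a plain dict in place of the real `PretrainedConfig`, not the library code itself): `original_max_position_embeddings` now only falls back to `max_position_embeddings` when unset, and the explicit `factor` is never overwritten.

```python
# Hypothetical values, plain dict instead of a PretrainedConfig (illustration only).
rope_scaling = {"rope_type": "yarn", "factor": 8.0, "original_max_position_embeddings": 4096}
max_position_embeddings = 32768  # post-yarn context length

# After this PR: the explicit `factor` is taken at face value, and
# `original_max_position_embeddings` only falls back to `max_position_embeddings` when unset.
factor = rope_scaling["factor"]
original_max_position_embeddings = (
    rope_scaling.get("original_max_position_embeddings") or max_position_embeddings
)

assert factor == 8.0  # no longer silently replaced by the implicit ratio
assert original_max_position_embeddings == 4096
```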
```diff
@@ -496,6 +490,33 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )

+    # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
+    # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
+    # However, for BC purposes, we allow the former to be unset.
+    original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
+    if original_max_position_embeddings is not None:
+        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+        implicit_factor = config.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor:
+            logger.warning_once(
+                f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
+            )
+    # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
+    # pre-yarn or the post-yarn context length?
+    # BC: we assume it is the pre-yarn context length.
+    else:
+        logger.warning_once(
+            "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
+            "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
+            "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
+            "factor) -- we recommend updating both fields for optimal downstream model usage."
+        )


 def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
```

Review comment (on lines +500 to +507, the mismatch warning): should we throw an error instead, if users explicitly set config values to be mismatching?

Reply: Agreed in theory, but it can be breaking for some models on the Hub 😢 As such, I believe a warning is more adequate.
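To make the added validation concrete, here is a self-contained sketch of the same consistency check under assumed, hypothetical values, with a plain dict in place of the config object and `warnings.warn` standing in for the library's `logger.warning_once`:

```python
import warnings

# Hypothetical, *inconsistent* config: the explicit factor is 8, but the two context
# lengths imply a factor of 40 (163840 / 4096).
max_position_embeddings = 163840
rope_scaling = {"rope_type": "yarn", "factor": 8.0, "original_max_position_embeddings": 4096}

factor = rope_scaling["factor"]
original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings")

if original_max_position_embeddings is not None:
    implicit_factor = max_position_embeddings / original_max_position_embeddings
    if implicit_factor != factor:
        # YaRN still uses the explicit factor (8.0); the config fields should be corrected.
        warnings.warn(f"explicit factor {factor} != implicit factor {implicit_factor}")
else:
    # BC path: without the field, `max_position_embeddings` is assumed to be the
    # pre-yarn context length.
    warnings.warn("original_max_position_embeddings is unset; assuming pre-yarn length")
```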
Review comment (on the block removed from `_compute_yarn_parameters`): here the implicit factor is taking precedence, which shouldn't happen (validation is added below, in the validation function).
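As a before/after comparison under assumed, hypothetical values (plain dicts rather than the real config class): the old code silently replaced the explicit `factor` with the implicit ratio, while after this PR the explicit `factor` is kept and the mismatch only triggers the warning added in `_validate_yarn_parameters`.

```python
# Hypothetical config whose fields disagree (4096 * 8 != 65536).
max_position_embeddings = 65536
rope_scaling = {"rope_type": "yarn", "factor": 8.0, "original_max_position_embeddings": 4096}

# Old behaviour (removed in this PR): the implicit ratio silently overrode the explicit factor.
old_factor = max_position_embeddings / rope_scaling["original_max_position_embeddings"]  # 16.0

# New behaviour: the explicit factor is kept; the validator only warns about the mismatch.
new_factor = rope_scaling["factor"]  # 8.0

print(old_factor, new_factor)  # 16.0 8.0
```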