Llama: RoPE refactor #32135
Conversation
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
#31999, which propagates the changes to all models, will fix this.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for all the work consolidating the rope logic!
Mostly some small questions and nits. The main comment is about the testing for all the compute functions.
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value to
determine which scaling to apply.
Expected contents:
Are all of the arguments expected, even if optional?
no, not at all :) the validation function exists to (among other things) detect incorrect parameter configurations
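For context, a minimal sketch of what such a key-validation step might look like; the function name and the exact key sets below are illustrative, not the actual transformers implementation:

```python
# Minimal sketch of a rope_scaling key-validation step (illustrative only --
# the real validation lives in the PR's rope utils, with different key sets
# per rope type).
def validate_rope_scaling(rope_scaling: dict) -> None:
    required_keys = {"rope_type"}
    optional_keys = {"factor", "attention_factor", "beta_fast", "beta_slow", "short_factor", "long_factor"}

    missing = required_keys - rope_scaling.keys()
    if missing:
        raise ValueError(f"Missing required keys in `rope_scaling`: {missing}")

    unrecognized = rope_scaling.keys() - required_keys - optional_keys
    if unrecognized:
        # Unknown keys are most likely typos or misconfigurations
        raise ValueError(f"Unrecognized keys in `rope_scaling`: {unrecognized}")
```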
"default": _compute_default_rope_parameters, | ||
"linear": _compute_linear_scaling_rope_parameters, | ||
"dynamic": _compute_dynamic_ntk_parameters, | ||
"yarn": _compute_yarn_parameters, | ||
"longrope": _compute_longrope_parameters, |
All of these should be tested in a test rope utils module, including checks for taking rope_kwargs and config, and their equivalence.
Added "rope_kwargs and config and their equivalence" ✅
Numerical checks will be a todo for the post-release follow-up PR (#31999)
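For illustration, a hedged sketch of the kind of config-vs-`rope_kwargs` equivalence check discussed here; the import path and function signatures are assumed from this PR and may differ in the final code:

```python
# Hedged sketch of a config-vs-kwargs equivalence check. Assumes the compute
# functions accept either a `config` or explicit keyword arguments and return
# an (inv_freq, attention_factor) tuple; exact signatures may differ.
import torch
from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

def check_default_rope_equivalence():
    config = LlamaConfig()
    rope_fn = ROPE_INIT_FUNCTIONS["default"]

    # Initialize once from the model config...
    inv_freq_from_config, scale_from_config = rope_fn(config=config, device="cpu")
    # ...and once from raw keyword arguments carrying the same values.
    inv_freq_from_kwargs, scale_from_kwargs = rope_fn(
        device="cpu",
        base=config.rope_theta,
        dim=config.hidden_size // config.num_attention_heads,
    )

    torch.testing.assert_close(inv_freq_from_config, inv_freq_from_kwargs)
    assert scale_from_config == scale_from_kwargs

check_default_rope_equivalence()
```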
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
    torch.testing.assert_close(yarn_sin_long, original_sin_long)
This works and is consistent with the other checks above. We should really make sure to check the rescaling values against specific numerical values in tests for the compute methods as well. This test tells us things have changed, but not whether the change is in the right direction or magnitude.
Fair, but that is a test that requires some numerical diving. Given our release goals -- would it be okay for me to add a todo/open an issue?
As long as it's actually done, then yes ;)
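As a rough illustration of the direction/magnitude checks deferred to the follow-up, here is a sketch using the simplest compute method (linear scaling), under the assumption that it divides the default inverse frequencies by `factor`; names and signatures are taken from the PR context and may differ:

```python
# Hedged sketch of a magnitude/direction check, using the simplest case:
# linear scaling is expected to divide the default inverse frequencies by
# exactly `factor`. Assumes (inv_freq, attention_factor) return values.
import torch
from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

def check_linear_scaling_magnitude(factor: float = 8.0):
    config = LlamaConfig()
    default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config=config, device="cpu")

    config.rope_scaling = {"rope_type": "linear", "factor": factor}
    linear_inv_freq, _ = ROPE_INIT_FUNCTIONS["linear"](config=config, device="cpu")

    # Checks not just that the values changed, but that they changed in the
    # expected direction and by the expected amount.
    torch.testing.assert_close(linear_inv_freq, default_inv_freq / factor)

check_linear_scaling_magnitude()
```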
LGTM
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
should it be rope scaling rather than rope init? nit!
I'd rather go with init -- the default rope (i.e. not scaled) uses this path as well
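A small sketch of that reasoning, showing how an unscaled config is routed through the same lookup (attribute and key names are assumed from the PR discussion, not the exact code):

```python
# Illustrative sketch: even a config without `rope_scaling` goes through the
# same ROPE_INIT_FUNCTIONS lookup via the "default" entry, hence "init" rather
# than "scaling". Key names ("rope_type" with a legacy "type" fallback) are
# assumed from the PR discussion.
from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

def pick_rope_init_fn(config):
    if getattr(config, "rope_scaling", None) is None:
        rope_type = "default"
    else:
        rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
    return ROPE_INIT_FUNCTIONS[rope_type]

# An unscaled Llama config resolves to the default compute function
print(pick_rope_init_fn(LlamaConfig()))
```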
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
to determine which scaling to apply.
Expected contents:
    `rope_type` (`str`):
        The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
        with 'default' being the original RoPE implementation.
    `factor` (`float`, *optional*):
        Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
        most scaling types, a `factor` of x will enable the model to handle sequences of length x *
        `max_position_embeddings`.
    `attention_factor` (`float`, *optional*):
        Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
        computation. If unspecified, it defaults to the value recommended by the implementation, using the
        `factor` field to infer the suggested value.
    `beta_fast` (`float`, *optional*):
        Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
        ramp function. If unspecified, it defaults to 32.
    `beta_slow` (`float`, *optional*):
        Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
        ramp function. If unspecified, it defaults to 1.
    `short_factor` (`List[float]`, *optional*):
        Only used with 'longrope'. The scaling factor to be applied to short contexts (<
        `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
        size divided by the number of attention heads divided by 2.
    `long_factor` (`List[float]`, *optional*):
        Only used with 'longrope'. The scaling factor to be applied to long contexts (>
        `max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
        size divided by the number of attention heads divided by 2.
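For readers, two illustrative `rope_scaling` dictionaries that follow the documented schema (values are made up for the example, not tuned recommendations):

```python
# Illustrative rope_scaling dictionaries following the documented schema
# (values are made up for the example, not tuned recommendations).
yarn_scaling = {
    "rope_type": "yarn",
    "factor": 8.0,        # target roughly 8x the original max_position_embeddings
    "beta_fast": 32.0,    # optional, documented default
    "beta_slow": 1.0,     # optional, documented default
    # "attention_factor" omitted -> inferred from `factor`
}

longrope_scaling = {
    "rope_type": "longrope",
    "factor": 4.0,
    # one entry per (hidden_size // num_attention_heads // 2); 64 is illustrative
    "short_factor": [1.0] * 64,
    "long_factor": [4.0] * 64,
}
```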
Ok this should leave enough freedom
Tho, the fact that we don't have a nested config makes it simpler; checks are run somewhere else, so it's pretty much equivalent.
def _rope_scaling_validation(self):
    """
    Validate the `rope_scaling` configuration.
    """
    if self.rope_scaling is None:
        return

    if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
        raise ValueError(
            "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
        )
    rope_scaling_type = self.rope_scaling.get("type", None)
    rope_scaling_factor = self.rope_scaling.get("factor", None)
    if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
        raise ValueError(
            f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
        )
    if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
        raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
nice to see that go away!
Beautiful - thanks for adding and iterating!
@require_torch
class RopeTest(unittest.TestCase):
🤗
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:
- LLaMA
- Falcon
- GPT-NeoX
- Olmo
- Persimmon
- Phi
- StableLM
- OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
Iterate on YaRN implementation for LLaMA and remove diff from remaining models for increased PR modularity. This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned') from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>
- Comply with the tensor building logic introduced in huggingface#30743
- Add referencing to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
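For readers unfamiliar with the method, a hedged, self-contained sketch of the YaRN idea described in the commit messages above (NTK-by-parts interpolation blended with extrapolation, plus an attention scaling term); constants, defaults, and helper names are illustrative, not the exact code in this PR:

```python
# Hedged sketch of YaRN frequency blending and attention scaling. Constants,
# defaults, and helper names are illustrative; see https://arxiv.org/abs/2309.00071
# and the PR's rope utils for the real implementation.
import math
import torch

def yarn_inv_freq(dim=128, base=10000.0, factor=8.0, original_max_pos=4096,
                  beta_fast=32.0, beta_slow=1.0):
    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs             # original frequencies
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)  # position-interpolated

    # Dimension index at which a frequency completes `num_rotations` turns
    # within the original context window (NTK-by-parts boundary).
    def correction_dim(num_rotations):
        return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    if high == low:
        high += 1e-3  # avoid a zero-width ramp

    # Linear ramp: 0 -> keep original frequency, 1 -> fully interpolate.
    ramp = torch.clamp((torch.arange(dim // 2).float() - low) / (high - low), 0, 1)
    inv_freq = inv_freq_interpolation * ramp + inv_freq_extrapolation * (1 - ramp)

    # Recommended attention (temperature) scaling from the paper.
    attention_factor = 0.1 * math.log(factor) + 1.0
    return inv_freq, attention_factor
```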
Force-pushed from 1416972 to c824be0.
Merged the yarn PR (precursor), now merging this one as soon as CI goes green.
Yarn PR is failing code quality checks on main. Could you make sure to rebase and then run make fix-copies etc. here before merge?
What does this PR do?
Same as #31999, but with llama being the only changed model.
Confirmed: slow tests are "passing" (same failures as `main`)
👉 `RUN_SLOW=1 py.test -vv tests/models/llama/test_modeling_llama.py`
👉 `RUN_SLOW=1 py.test -vv tests/utils/test_cache_utils.py`
👉 `RUN_SLOW=1 py.test -vv tests/utils/test_modeling_rope_utils.py` (new tests)

Throughput benchmarks: no changes vs previous `main` 💔