Llama: RoPE refactor #32135

Merged · 30 commits · Jul 23, 2024

Conversation

gante (Member) commented on Jul 22, 2024

What does this PR do?

Same as #31999, but with llama being the only changed model.


Confirmed: slow tests are "passing" (same failures as main)
👉 RUN_SLOW=1 py.test -vv tests/models/llama/test_modeling_llama.py
👉 RUN_SLOW=1 py.test -vv tests/utils/test_cache_utils.py
👉 RUN_SLOW=1 py.test -vv tests/utils/test_modeling_rope_utils.py (new tests)


Throughput benchmarks: No changes vs previous main 💔
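
For context, a rough sketch of the shape of the refactor, pieced together from the diff excerpts quoted below; the exact call signature, return values, and accepted config keys are assumptions here, not guaranteed API:

from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# RoPE scaling now lives as a flat dict on the config; 'rope_type' selects the variant.
config = LlamaConfig(rope_scaling={"rope_type": "yarn", "factor": 2.0})

# A single registry of parameter-computing functions replaces per-model RotaryEmbedding subclasses.
rope_init_fn = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]]
inv_freq, attention_scaling = rope_init_fn(config)  # assumed return signature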

Comment on lines +83 to +84
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
# TODO(joao): add me back asap :)
Member Author

#31999, which propagates the changes to all models, will fix this.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment

Thanks for all the work consolidating the rope logic!

Mostly some small questions and nits. The main comment is about the testing for all the compute functions.

src/transformers/modeling_rope_utils.py (4 outdated review threads, resolved)
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value to
determine which scaling to apply.
Expected contents:
Collaborator

Are all of the arguments expected, even if optional?

Member Author

No, not at all :) The validation function exists to (among other things) detect incorrect parameter configurations.
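
(For illustration only, a minimal sketch of the kind of check such a validation function can do; the function name and exact messages are made up here, not the PR's implementation:)

from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

def _validate_rope_scaling(rope_scaling):
    # Hypothetical validator: accept None, reject unknown variants and missing/odd factors.
    if rope_scaling is None:
        return
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
    if rope_type not in ROPE_INIT_FUNCTIONS:
        raise ValueError(f"Unknown `rope_type`: {rope_type}. Expected one of {list(ROPE_INIT_FUNCTIONS)}")
    if rope_type != "default" and not isinstance(rope_scaling.get("factor"), float):
        raise ValueError("`rope_scaling['factor']` must be a float for scaled RoPE variants")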

Comment on lines +323 to +304
"default": _compute_default_rope_parameters,
"linear": _compute_linear_scaling_rope_parameters,
"dynamic": _compute_dynamic_ntk_parameters,
"yarn": _compute_yarn_parameters,
"longrope": _compute_longrope_parameters,
Collaborator

All of these should be tested in a rope utils test module, including checks that passing `rope_kwargs` and passing a config are equivalent.

Member Author

Added "rope_kwargs and config and their equivalence" ✅

Numerical checks will be a todo for the post-release follow-up PR (#31999)
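
(A sketch of what the config vs. `rope_kwargs` equivalence check can look like; the keyword names `base` and `dim` for the kwargs path are assumptions here:)

import torch
from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = LlamaConfig()
head_dim = config.hidden_size // config.num_attention_heads

compute_default = ROPE_INIT_FUNCTIONS["default"]
inv_freq_from_config, _ = compute_default(config=config, device="cpu")
inv_freq_from_kwargs, _ = compute_default(device="cpu", base=config.rope_theta, dim=head_dim)

# Both paths should produce identical inverse frequencies.
torch.testing.assert_close(inv_freq_from_config, inv_freq_from_kwargs)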

src/transformers/models/llama/configuration_llama.py Outdated Show resolved Hide resolved
Comment on lines 477 to 490
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_short, original_cos_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_short, original_sin_short)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_cos_long, original_cos_long)
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
Collaborator

This works and is consistent with the other checks above. We should really make sure to check the rescaling values against specific numerical values in the tests for the compute methods as well. This test tells us things have changed, but not whether the change is in the right direction or magnitude.

Member Author

Fair, but that is a test that requires some numerical diving. Given our release goals -- would it be okay for me to add a todo/open an issue?

Collaborator

As long as it's actually done, then yes ;)
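
(For the record, the kind of numerical check being deferred here could be as small as the following, taking linear scaling as the easiest case and assuming it is implemented by dividing the inverse frequencies by the factor:)

import torch
from transformers import LlamaConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = LlamaConfig()
factor = 4.0

default_inv_freq, _ = ROPE_INIT_FUNCTIONS["default"](config, device="cpu")

config.rope_scaling = {"rope_type": "linear", "factor": factor}
linear_inv_freq, _ = ROPE_INIT_FUNCTIONS["linear"](config, device="cpu")

# Linear scaling stretches positions by `factor`, so each inverse frequency shrinks by `factor`.
torch.testing.assert_close(linear_inv_freq, default_inv_freq / factor)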

src/transformers/models/llama/modeling_llama.py (2 outdated review threads, resolved)
@ArthurZucker (Collaborator) left a comment

LGTM

self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
Collaborator

should it be rope scaling rather than rope init? nit!

Member Author

I'd rather go with init -- the default rope (i.e. not scaled) uses this path as well
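
(To illustrate the point: the unscaled 'default' type flows through the same registry lookup, so "init" is the more accurate name. A simplified sketch of that init path, not the verbatim module code:)

import torch
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

class SimplifiedRotaryEmbedding(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        # Both the plain and the scaled cases resolve to an entry in the same registry.
        self.rope_type = rope_scaling["rope_type"] if rope_scaling is not None else "default"
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(config)
        self.register_buffer("inv_freq", inv_freq, persistent=False)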

Comment on lines 84 to 112
Dictionary containing the scaling configuration for the RoPE embeddings. IMPORTANT: RoPE scaling expects
`max_position_embeddings` to remain unchanged -- some methods, like 'longrope', require the original value
to determine which scaling to apply.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope'],
with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
`max_position_embeddings`.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2.
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`max_position_embeddings` * `factor`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2.
Collaborator

Ok this should leave enough freedom

Collaborator

Though, the fact that we don't have a nested config makes it simpler; the checks are run somewhere else, so it's pretty much equivalent.
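
(As a concrete illustration of the flat, non-nested dict described in the docstring above; the values below are arbitrary examples:)

from transformers import LlamaConfig

# 'linear' only needs the scaling factor.
linear_config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})

# 'yarn' requires `factor`; `attention_factor`, `beta_fast` and `beta_slow` fall back to defaults.
yarn_config = LlamaConfig(rope_scaling={"rope_type": "yarn", "factor": 4.0})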

Comment on lines 173 to 202

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
Collaborator

Nice to see that go away!

src/transformers/models/llama/modeling_llama.py (2 outdated review threads, resolved)
@amyeroberts (Collaborator) left a comment

Beautiful - thanks for adding and iterating!



@require_torch
class RopeTest(unittest.TestCase):
Collaborator

🤗

mig-mfreitas and others added 21 commits July 23, 2024 08:58
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts
Interpolation and Attention Scaling methods, improving upon existing
RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <miguel.pessanha.almeida@tecnico.ulisboa.pt>
Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <miguelmontefreitas@tecnico.ulisboa.pt>
- Comply with the tensor building logic introduced in huggingface#30743
- Add referencing to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <mig-mfreitas@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
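
(For reference, the attention-scaling piece of YaRN described in the first commit message above amounts to a logarithmic temperature on the attention computation; a minimal sketch using the formula from the YaRN paper, with an illustrative function name:)

import math

def yarn_attention_factor(scaling_factor: float) -> float:
    # YaRN's recommended attention scaling grows logarithmically with the context-extension factor.
    if scaling_factor <= 1.0:
        return 1.0
    return 0.1 * math.log(scaling_factor) + 1.0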
gante and others added 7 commits July 23, 2024 08:58
@gante force-pushed the llama_rope_refactor branch from 1416972 to c824be0 on July 23, 2024, 08:58
@gante (Member Author) commented on Jul 23, 2024

Merged the YaRN PR (precursor); now merging this one as soon as CI goes green.

@amyeroberts (Collaborator)

The YaRN PR is failing code quality checks on main. Could you make sure to rebase and then run make fix-copies etc. here before merging?

6 participants