RoPE models: add numerical sanity-check test for RoPE scaling #29808

Merged: gante merged 4 commits into huggingface:main from fix_29765 on Mar 28, 2024

Conversation

@gante (Member) commented on Mar 22, 2024

What does this PR do?

Fixes #29765

#29765 asks a pertinent question, and I had to look at the code to confirm the answer. This should be checked automatically by a test instead: confirm that RoPE scaling is applied as intended.

@gante requested a review from amyeroberts on March 22, 2024 at 12:17

@amyeroberts (Collaborator) left a comment

Thanks for adding this test!

Just a few questions from my side:

  • Rather than checking generations, could we check the embedding values to see whether they've been rescaled as expected? (A rough sketch of such a check follows this list.)
  • The note says it matches 'our initial rope scaling' - is this here for BC, or is it correct and simply describes the initial feature?
  • Have the outputs been compared to see what happens if double scaling does occur?
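
A rough sketch of the kind of numerical check suggested in the first bullet, in plain torch rather than the PR's actual test code (the dimensions and scaling factor are made up): with linear RoPE scaling by a factor f, the cos/sin tables at position p should equal the unscaled tables at position p / f.

```python
import torch

def rope_tables(positions, dim=64, base=10000.0):
    # Standard RoPE frequencies: inv_freq_i = base^(-2i/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]
    return freqs.cos(), freqs.sin()

scaling_factor = 4
positions = torch.arange(32)

# Linear scaling simply divides the position index by the factor.
scaled_cos, scaled_sin = rope_tables(positions / scaling_factor)
original_cos, original_sin = rope_tables(positions)

# Scaled position p * factor should land exactly on unscaled position p.
torch.testing.assert_close(scaled_cos[::scaling_factor], original_cos[: 32 // scaling_factor])
torch.testing.assert_close(scaled_sin[::scaling_factor], original_sin[: 32 // scaling_factor])
```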

tests/models/llama/test_modeling_llama.py: 3 review threads (outdated, resolved)
@gante (Member, Author) commented on Mar 27, 2024

> Rather than checking generations, could we check the embedding values to see whether they've been rescaled as expected?

Much better 👀 and doesn't need to be a slow test. Going to rework the PR

@gante changed the title from "Llama: add hard rope scaling test" to "RoPE models: add numerical sanity-check test for RoPE scaling" on Mar 27, 2024
@gante requested a review from amyeroberts on March 27, 2024 at 17:55
@gante (Member, Author) commented on Mar 27, 2024

@amyeroberts I've deleted the previous slow test and added a numerical test on the embeddings instead, as you suggested -- much, much faster to run, and a more precise sanity check.

I've added the test to all RoPE-scaling-compatible models. It can't be abstracted into the common mixin (for now), as there are a few per-model variations while we rework torch.compile/caching.
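
For context, here is a minimal sketch of the dynamic-scaling half of such a check, again in plain torch rather than the library's rotary embedding classes; the base-rescaling rule is one common dynamic-NTK formulation and is an assumption here, not a statement about every model in the PR. The property asserted: dynamic scaling is a no-op within the original context length and only changes the tables once the sequence outgrows it.

```python
import torch

def rope_tables(seq_len, dim=64, base=10000.0):
    # cos/sin tables for positions 0..seq_len-1 with frequencies base^(-2i/dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = torch.arange(seq_len, dtype=torch.float32)[:, None] * inv_freq[None, :]
    return freqs.cos(), freqs.sin()

def dynamic_ntk_tables(seq_len, max_position_embeddings=128, factor=2.0, dim=64, base=10000.0):
    # Assumed dynamic-NTK rule: rescale the base only when the sequence exceeds
    # the training context; otherwise behave exactly like the original RoPE.
    if seq_len > max_position_embeddings:
        base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
    return rope_tables(seq_len, dim=dim, base=base)

# Within the training context: identical to the unscaled embedding.
torch.testing.assert_close(dynamic_ntk_tables(16)[0], rope_tables(16)[0])

# Past the training context: the tables must differ.
assert not torch.allclose(dynamic_ntk_tables(256)[0], rope_tables(256)[0], atol=1e-5)
```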

@@ -360,6 +366,96 @@ def test_phi_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling_from_config(self, scaling_type):
@gante (Member, Author) commented:

This one is not a new test, but rather a previously missing one: it is copied from the same test in Falcon (Phi is mostly copied from Falcon).

@@ -438,6 +443,65 @@ def test_model_rope_scaling(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

def test_model_rope_scaling(self):
@gante (Member, Author) commented on Mar 27, 2024:

This test is the same in all models, with two variations (illustrated in the sketch after this list):

  • Llama passes position_ids to RoPE rather than an integer sequence length, due to its torch.compile rework.
  • GPTNeoX uses base=config.rotary_emb_base in the RoPE initialization, while all other models use base=config.rope_theta. It is the same parameter, just under a different name.
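
A hypothetical illustration of how those two differences surface in test code; the helper names and call signatures below are assumptions for illustration, not the library's actual API:

```python
def get_rope_base(config):
    # Same hyperparameter, two names: most configs call it `rope_theta`,
    # GPTNeoX calls it `rotary_emb_base`.
    return getattr(config, "rope_theta", getattr(config, "rotary_emb_base", 10000.0))

def call_rope(rope, x, position_ids, takes_position_ids):
    # Llama's torch.compile rework passes the position_ids tensor itself;
    # the other models pass an integer sequence length instead.
    if takes_position_ids:
        return rope(x, position_ids)
    return rope(x, seq_len=position_ids.shape[-1])
```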

@amyeroberts (Collaborator) replied:

In this case, we can employ some Copied froms :)

I don't think there's an easy way for the Llama test - but for GPTNeoX, remapping with x->y should work.
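
For illustration, this is the kind of `# Copied from` annotation being referred to; the exact source path is hypothetical, but the `with X->Y` remapping syntax is what `make fix-copies` applies when keeping the copy in sync:

```python
# Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base
def test_model_rope_scaling(self):
    ...
```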

@amyeroberts (Collaborator) left a comment

Looks great - super clear and easy to follow tests ❤️


@gante merged commit 441de62 into huggingface:main on Mar 28, 2024
18 checks passed
@gante deleted the fix_29765 branch on March 28, 2024 at 11:25
itazap pushed a commit that referenced this pull request May 14, 2024
* add hard rope scaling test

* make fixup

* quick rope scaling tests

* add copy statements
Closes: Extra linear scaling in LlamaRotaryEmbedding classes (#29765)