[RoPE] run RoPE tests when the model uses RoPE #40630
Changes from all commits: 7f85b69, 3ff9aad, 285c1b5, 8c51fb5, 9330b63
```diff
@@ -18,7 +18,7 @@
 import pytest
 from parameterized import parameterized

-from transformers import set_seed
+from transformers import PretrainedConfig, set_seed
 from transformers.testing_utils import (
     is_flaky,
     require_flash_attn,
```
```diff
@@ -230,7 +230,6 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     test_pruning = False
     model_tester_class = None
     all_model_classes = None
-    rotary_embedding_layer = None  # Enables RoPE tests if set
     pipeline_model_mapping = None

     def setUp(self):
```
```diff
@@ -319,21 +318,28 @@ def test_token_classification_model(self):
     @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
-        if self.rotary_embedding_layer is None:
-            self.skipTest("Rotary embedding layer not set")
+        """
+        Tests that we can initialize a model with RoPE scaling in the config, that it can run a forward pass, and
+        that a few basic model output properties are honored.
+        """
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()

+        if not _config_supports_rope_scaling(config):
+            self.skipTest("This model does not support RoPE scaling")
+
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

         set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        config.rope_scaling = {"rope_type": "default"}
         original_model = self.model_tester_class.base_model_class(config)
         original_model.to(torch_device)
         original_model.eval()
         original_short_output = original_model(short_input).last_hidden_state
         original_long_output = original_model(long_input).last_hidden_state

         set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
+        config.rope_scaling = {"rope_type": scaling_type, "factor": 10.0}
         scaled_model = self.model_tester_class.base_model_class(config)
         scaled_model.to(torch_device)
         scaled_model.eval()
```
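For readers less familiar with the config knob this test exercises, here is a minimal sketch of setting `rope_scaling` on a real config. It is not part of the PR; `LlamaConfig`/`LlamaModel` are assumed examples of a decoder-only RoPE model, and the tiny sizes are only there to keep the model cheap to build.

```python
# Not part of the PR: sketch of the `rope_scaling` dict format the test sets on the config.
import torch
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    num_hidden_layers=2,
    hidden_size=64,
    num_attention_heads=4,
    intermediate_size=128,
    rope_scaling={"rope_type": "yarn", "factor": 10.0},  # same dict format as in the test
)
model = LlamaModel(config).eval()  # the rotary embedding now uses YaRN-scaled frequencies

input_ids = torch.randint(0, config.vocab_size, (1, 10))
hidden = model(input_ids).last_hidden_state  # forward pass runs with scaled RoPE
```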
```diff
@@ -350,10 +356,26 @@ def test_model_rope_scaling_from_config(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

-    def test_model_rope_scaling(self):
-        if self.rotary_embedding_layer is None:
-            self.skipTest("Rotary embedding layer not set")
+    def test_model_rope_scaling_frequencies(self):
+        """Tests the frequency properties of the different RoPE scaling types on the model RoPE layer."""
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()

+        if not _config_supports_rope_scaling(config):
+            self.skipTest("This model does not support RoPE scaling")
+
+        # Retrieves the RoPE layer class from the base model class. Uses `.named_modules()` to avoid hardcoding the
+        # named location of the RoPE layer class.
+        base_model = self.model_tester.base_model_class(config)
+        possible_rope_attributes = [
+            "rotary_emb",  # most common case
+            "global_rotary_emb",
+            "local_rotary_emb",
+        ]
+        for name, module in base_model.named_modules():
+            if any(potential_name in name for potential_name in possible_rope_attributes):
+                rope_class = type(module)
+                break
+
         scaling_factor = 10
         short_input_length = 10
         long_input_length = int(config.max_position_embeddings * 1.5)
```

Review thread on this hunk:
- Member: we might need to
- Author (Member): Good point. Given that the tests are only run on decoder-only models for now, I'd rather leave as is (and upgrade when it's needed) 🤗
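As a reference for the lookup above, a minimal sketch of the same `named_modules()` scan outside the test harness. It is not part of the PR; `LlamaModel` is an assumed example, and on current transformers the scan should land on `LlamaRotaryEmbedding`.

```python
# Not part of the PR: sketch of finding the RoPE module class by name substring.
from transformers import LlamaConfig, LlamaModel

model = LlamaModel(LlamaConfig(num_hidden_layers=1, hidden_size=64, num_attention_heads=4, intermediate_size=128))

rope_class = None
for name, module in model.named_modules():
    if "rotary_emb" in name:  # same substring match as the test above
        rope_class = type(module)
        break

print(rope_class)  # e.g. <class 'transformers.models.llama.modeling_llama.LlamaRotaryEmbedding'>
```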
```diff
@@ -368,16 +390,17 @@ def test_model_rope_scaling(self):
         position_ids_long = position_ids_long.unsqueeze(0)

         # Sanity check original RoPE
-        original_rope = self.rotary_embedding_layer(config=config).to(torch_device)
+        config.rope_scaling = {"rope_type": "default"}
+        original_rope = rope_class(config=config).to(torch_device)
         original_cos_short, original_sin_short = original_rope(x, position_ids_short)
         original_cos_long, original_sin_long = original_rope(x, position_ids_long)
         torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
         torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])

         # Sanity check linear RoPE scaling
         # New position "x" should match original position with index "x/scaling_factor"
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        linear_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device)
+        config.rope_scaling = {"rope_type": "linear", "factor": scaling_factor}
+        linear_scaling_rope = rope_class(config=config).to(torch_device)
         linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
         linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
         torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
```
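The linear-scaling check relies on the property spelled out in the comment: with scaling factor s, linear RoPE scaling divides the rotation angle by s, so position p in the scaled layer reproduces position p/s in the original layer. In symbols (a sketch, not taken from the diff; f_i is the per-dimension inverse frequency, i.e. `inv_freq`):

```math
\cos_{\text{linear}}(p)_i = \cos\!\Big(\tfrac{p}{s}\, f_i\Big) = \cos_{\text{orig}}\!\Big(\tfrac{p}{s}\Big)_i,
\qquad
\sin_{\text{linear}}(p)_i = \sin_{\text{orig}}\!\Big(\tfrac{p}{s}\Big)_i
```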
```diff
@@ -390,8 +413,8 @@ def test_model_rope_scaling(self):
         # Sanity check Dynamic NTK RoPE scaling
         # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
         # with scaling_factor (or that `inv_freq` decreases)
-        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
-        ntk_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device)
+        config.rope_scaling = {"rope_type": "dynamic", "factor": scaling_factor}
+        ntk_scaling_rope = rope_class(config=config).to(torch_device)
         ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
         ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
         torch.testing.assert_close(ntk_cos_short, original_cos_short)
```
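For context on why scaling is "only observed after a long input is fed": dynamic NTK keeps the RoPE base unchanged while the sequence fits in `max_position_embeddings` and only rescales it for longer inputs. The usual rule, as a sketch rather than a quote from the implementation (s is the factor, L the current sequence length, L_max the original maximum, d the rotary dimension):

```math
\text{base}' = \text{base}\cdot\Big(\frac{s\,L}{L_{\max}} - (s-1)\Big)^{\frac{d}{d-2}},
\qquad
f_i' = \text{base}'^{\,-2i/d}
```

At L = L_max the bracket equals 1, so short inputs are untouched (hence the assertion that the short cos matches the original); for L > L_max the base grows and `inv_freq` shrinks.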
```diff
@@ -404,8 +427,8 @@ def test_model_rope_scaling(self):

         # Sanity check Yarn RoPE scaling
         # Scaling should be over the entire input
-        config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
-        yarn_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device)
+        config.rope_scaling = {"rope_type": "yarn", "factor": scaling_factor}
+        yarn_scaling_rope = rope_class(config=config).to(torch_device)
         yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
         yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
         torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
```
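Unlike dynamic NTK, YaRN rewrites the inverse frequencies up front, which is why the comment notes that "scaling should be over the entire input": even short inputs see different cos/sin tables. Schematically (a sketch of the NTK-by-parts blend, not quoted from the implementation):

```math
f_i' = (1-\gamma_i)\,\frac{f_i}{s} + \gamma_i\, f_i
```

where γ_i is roughly 1 for high-frequency dimensions (kept as-is) and roughly 0 for low-frequency dimensions (interpolated by the factor s); YaRN additionally applies an attention-temperature factor to the cos/sin outputs.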
```diff
@@ -450,3 +473,15 @@ def test_flash_attn_2_equivalence(self):
         logits = outputs.hidden_states[-1]
         logits_fa = outputs_fa.hidden_states[-1]
         torch.testing.assert_close(logits_fa, logits, atol=3e-2, rtol=3e-2)
+
+
+def _config_supports_rope_scaling(config: PretrainedConfig) -> bool:
+    """Returns whether a certain model config supports RoPE scaling parameterization."""
+    # Has rope_scaling -> model was designed with rope scaling in mind
+    # Has rope_theta (and no rope_scaling) -> probably an older model, but should support rope scaling as well
+    main_config_has_rope = hasattr(config, "rope_scaling") or hasattr(config, "rope_theta")
+    sub_config_has_rope = any(
+        hasattr(config[sub_config], "rope_scaling") or hasattr(config[sub_config], "rope_theta")
+        for sub_config in config.sub_configs.keys()
+    )
+    return main_config_has_rope or sub_config_has_rope
```
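For reference, a quick illustration (not part of the PR) of the attributes this helper keys off, assuming `LlamaConfig` as a RoPE model and `BertConfig` as a RoPE-free one:

```python
# Not part of the PR: the hasattr checks that decide whether the RoPE tests run or skip.
from transformers import BertConfig, LlamaConfig

llama_config = LlamaConfig()
bert_config = BertConfig()

print(hasattr(llama_config, "rope_theta"), hasattr(llama_config, "rope_scaling"))  # True True -> RoPE tests run
print(hasattr(bert_config, "rope_theta"), hasattr(bert_config, "rope_scaling"))    # False False -> RoPE tests skipped
```

Because the helper only uses `hasattr`, older configs that expose `rope_theta` but predate the `rope_scaling` attribute are still picked up, as the comments in the diff note.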