-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RoPE: relaxed rope validation #32182
Merged
+94
−37
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -526,6 +526,60 @@ def test_rope_class_retrocompatibility(self): | |
torch.testing.assert_close(old_cos_long, new_cos_long) | ||
torch.testing.assert_close(old_sin_long, new_sin_long) | ||
|
||
def test_model_loading_old_rope_configs(self): | ||
def _reinitialize_config(base_config, new_kwargs): | ||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation | ||
# steps. | ||
base_config_dict = base_config.to_dict() | ||
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) | ||
return new_config | ||
|
||
# from untouched config -> ✅ | ||
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() | ||
original_model = LlamaForCausalLM(base_config).to(torch_device) | ||
original_model(**model_inputs) | ||
|
||
# from a config with the expected rope configuration -> ✅ | ||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) | ||
original_model = LlamaForCausalLM(config).to(torch_device) | ||
original_model(**model_inputs) | ||
|
||
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC | ||
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) | ||
original_model = LlamaForCausalLM(config).to(torch_device) | ||
original_model(**model_inputs) | ||
|
||
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) | ||
config = _reinitialize_config( | ||
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} | ||
) | ||
self.assertTrue(config.rope_scaling["type"] == "linear") | ||
self.assertTrue(config.rope_scaling["rope_type"] == "linear") | ||
original_model = LlamaForCausalLM(config).to(torch_device) | ||
original_model(**model_inputs) | ||
|
||
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning | ||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: | ||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) | ||
original_model = LlamaForCausalLM(config).to(torch_device) | ||
original_model(**model_inputs) | ||
self.assertEqual(len(logs.output), 1) | ||
self.assertIn("factor field", logs.output[0]) | ||
|
||
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning | ||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: | ||
config = _reinitialize_config( | ||
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} | ||
) | ||
original_model = LlamaForCausalLM(config).to(torch_device) | ||
original_model(**model_inputs) | ||
self.assertEqual(len(logs.output), 1) | ||
self.assertIn("Unrecognized keys", logs.output[0]) | ||
|
||
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception | ||
with self.assertRaises(KeyError): | ||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we are not testing the |
||
|
||
@require_flash_attn | ||
@require_torch_gpu | ||
@require_bitsandbytes | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the linked issues - it seems the problem was just with
rope_type
vstype
: does this mean older configs also have other keys?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts No other keys that we should support in our transformers code (hence the original exception). The only change was the renaming of (
type
->rope_type
), as suggested in the PR review.However, users may be using the base config class in their own custom projects. e.g. Phi3 used custom fields for rope scaling (which are no longer custom because we merged them). For that reason, this PR:
The only exceptions that persist are the ones mentioned at the top of this PR page, which would result in other logic errors down the line. I hope this clears out the issues we're seeing, while preventing as many future issues as possible from bad utilization.