🚨 Llama: update rope scaling to match static cache changes #29143
Conversation
```diff
@@ -362,7 +362,6 @@ def test_save_load_fast_init_from_base(self):
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    @unittest.skip("TODO @gante fix this for Llama")
```
This test was fixed as a result of the changes in this PR :)
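For context, here is a minimal sketch of how `@parameterized.expand` fans a single rope-scaling test out into one case per scaling type. The test and config names below are illustrative, not the exact transformers test:

```python
import unittest

from parameterized import parameterized


class RopeScalingTest(unittest.TestCase):
    @parameterized.expand([("linear",), ("dynamic",)])
    def test_model_rope_scaling(self, scaling_type):
        # Generates two test cases from the single method:
        # test_model_rope_scaling_0_linear and test_model_rope_scaling_1_dynamic.
        rope_scaling = {"type": scaling_type, "factor": 2.0}
        self.assertIn(rope_scaling["type"], ("linear", "dynamic"))


if __name__ == "__main__":
    unittest.main()
```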
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
🧼 nice cleanup!
Main concern: BC. Let's keep `cos_cached` and `sin_cached` for one release, and then we can open a PR directly on main to remove them!
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) | ||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) | ||
def forward(self, x, position_ids, seq_len=None): |
I am alright with this, but it is breaking for any libs that rely on `sin_cached` and `cos_cached`. Same for the static cache PR!
Let's just add a mention that they will be removed next release, and still compute cos and sin!
This is the cool part -- it calls super's forward, which in turn caches sin/cos (see here). BC is preserved 🙌
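To illustrate the pattern being described, here is a simplified sketch (not the exact transformers code): the scaling subclass only rescales its positions and delegates to the base class's `forward`, so whatever the base class caches is cached for the subclass too.

```python
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        # freqs: [batch, seq_len, dim // 2]
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos().to(x.dtype), emb.sin().to(x.dtype)
        # Cache the last computed values so code that still reads
        # cos_cached / sin_cached keeps working (the BC point above).
        self.cos_cached, self.sin_cached = cos, sin
        return cos, sin


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    def __init__(self, dim, scaling_factor=1.0, **kwargs):
        super().__init__(dim, **kwargs)
        self.scaling_factor = scaling_factor

    def forward(self, x, position_ids):
        # Rescale the positions, then reuse the parent's forward; the
        # parent's caching then happens for the subclass as well.
        return super().forward(x, position_ids / self.scaling_factor)
```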
Yes, but we need a warning to deprecate!
A follow-up is fine
I'm not sure I follow -- the warning is here. Or were you thinking of some other warning?
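As a hedged illustration of the kind of deprecation warning being discussed (the class name, message, and warning category below are assumptions, not the PR's actual code), the old attributes could be exposed as properties that warn on access:

```python
import warnings


class RotaryEmbeddingWithDeprecatedCaches:
    """Illustrative only: keep the old attributes for one release, with a warning."""

    def __init__(self):
        self._cos_cached = None  # populated by forward(), as before
        self._sin_cached = None

    @property
    def cos_cached(self):
        warnings.warn(
            "cos_cached will be removed in a future release; use the values "
            "returned by forward() instead.",
            FutureWarning,
        )
        return self._cos_cached

    @property
    def sin_cached(self):
        warnings.warn(
            "sin_cached will be removed in a future release; use the values "
            "returned by forward() instead.",
            FutureWarning,
        )
        return self._sin_cached
```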
Perfect! Had not seen this when I checked the diff
```diff
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        cos, sin = super().forward(x, position_ids, seq_len)
```
a lot cleaner!
@younesbelkada also pointed out that the shape of the output of the rope layer is different from before, so this is a bit breaking. Let's add a big 🔴 to the PR to make sure we know that there are breaking changes!
Tests all pass on the PEFT end! Thanks for the notice 💪
What does this PR do?
(see title :))

What's breaking? The shape of the returned sin/cos caches is changed (sin/cos for all positions -> sin/cos only for the positions in `position_ids`). Note that this breaking change was also present in the static cache PR, for the main RoPE class (#27931).

Review suggestion:
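A small sketch of the shape change under illustrative sizes (head_dim=8, a 16-position cache, batch of 1, 4 requested positions); none of this is the PR's literal code:

```python
import torch

head_dim, max_positions = 8, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))

# Before: the caches covered *all* positions up to the max seen so far.
t = torch.arange(max_positions).float()
freqs_all = torch.outer(t, inv_freq)
emb_all = torch.cat((freqs_all, freqs_all), dim=-1)
print(emb_all.cos().shape)  # torch.Size([16, 8])

# After: forward returns sin/cos only for the requested position_ids,
# with a leading batch dimension.
position_ids = torch.arange(4)[None, :]  # [batch=1, seq_len=4]
freqs_sel = position_ids[..., None].float() * inv_freq
emb_sel = torch.cat((freqs_sel, freqs_sel), dim=-1)
print(emb_sel.cos().shape)  # torch.Size([1, 4, 8])
```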