Compile compatibility for decoder-only models #32617
Conversation
Added a few comments, mostly about aligning with llama
Ran test_generate_compile_fullgraph and test_static_cache_matches_dynamic on all models + ran slow tests on models touched by this PR.
💛
@@ -899,9 +895,24 @@ def prepare_inputs_for_generation(
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds}
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
missing #Copied from ...?
not really, Bloom has ALiBi and needs a 2D attention mask for that. So we can't expand it to 4D, and instead choose to append zeros to the attention mask to give it a static shape.
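A toy illustration of that idea, assuming a hypothetical helper name (this is not Bloom's actual code): the 2D padding mask is extended with zeros up to a fixed target length so its shape stays constant across generation steps.

```python
import torch

def pad_attention_mask_to_static(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    # Keep the mask 2D (ALiBi needs it that way) and pad the not-yet-filled
    # cache slots with zeros so the shape is always (batch, target_length).
    batch_size, seq_length = attention_mask.shape
    if seq_length < target_length:
        pad = torch.zeros(
            batch_size, target_length - seq_length,
            dtype=attention_mask.dtype, device=attention_mask.device,
        )
        attention_mask = torch.cat([attention_mask, pad], dim=-1)
    return attention_mask

mask = torch.ones(2, 5, dtype=torch.long)     # 5 real tokens
print(pad_attention_mask_to_static(mask, 8))  # zeros appended up to length 8
```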
Updated with @gante's comments and used the new RoPE modeling in all models. Ready for review!
Failing tests are not related
💎 thanks so much for this tedious work, well done 🥳
What is left is to make sure the compile tests pass!
does it support compile? (not seeing the supports_static_cache)
yes, it does. You might have missed it :)
@@ -273,9 +380,29 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
this is potentially breaking no? (no more offset)
Hmm right, lemme check this
update: just verified we don't need to slice anymore, because we apply RoPE directly on the current positions. Previously we applied RoPE for all positions up to the current one and had to slice out the cached positions
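For reference, a minimal sketch of the new-style helper (essentially the Llama-style `apply_rotary_pos_emb`, reproduced here from memory, not copied from this PR): `cos` and `sin` are computed only for the current `position_ids` before the call, so there is no `offset` left to slice.

```python
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # `cos`/`sin` are already gathered at the current positions, so they are
    # applied to q/k directly -- no `offset` argument and no slicing of a
    # precomputed full-length table.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```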
if past_key_value is not None:
    # Activate slicing cache only if the config has a value `sliding_windows` attribute
    cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
    kv_seq_len = key_states.shape[-2] + cache_position[0]
i don't remember why we don't use cache_position[-1]
Because the last position is the whole past KV length, which causes an incorrect length in prefill or uncached generation. Maybe we should switch to simply past_length = cache_position[-1] everywhere?
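To make the distinction concrete, a small illustration with made-up shapes (not code from the PR):

```python
import torch

# Prefill: 5 prompt tokens, nothing cached yet.
key_states = torch.zeros(1, 8, 5, 64)   # (batch, heads, seq_len, head_dim)
cache_position = torch.arange(5)         # tensor([0, 1, 2, 3, 4])

kv_seq_len = key_states.shape[-2] + cache_position[0]  # 5 + 0 = 5, the true KV length
wrong_len = cache_position[-1]                          # 4, off by one during prefill

# Decoding: one new token with 5 tokens already cached.
key_states = torch.zeros(1, 8, 1, 64)
cache_position = torch.tensor([5])
kv_seq_len = key_states.shape[-2] + cache_position[0]  # 1 + 5 = 6, again correct
```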
Thank you for these very laborious changes 🙏
@simonJJJ I added the new RoPE embedding for Qwen2-VL in this PR. Since I changed Qwen2, the changes were automatically propagated.
@ArthurZucker @gante changed deprecation to v4.46 and added Qwen2-VL. Ran the tests again to check everything is okay. Let me know if you have any comments
@@ -870,7 +870,7 @@ def _update_causal_mask(
     # to infer the attention mask.

     # cache_position must be valid here no matter which cache we use
-    past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
Same as in Llama: using cache_position is dynamic control flow, which is not currently supported by compile. The fullgraph-compile test fails without this change
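To illustrate the constraint with a toy example (this is not the transformers code, just a sketch of the failure mode): branching on a plain Python int traces fine under `fullgraph=True`, while branching on a value read out of a tensor is data-dependent and typically errors out.

```python
import torch

def branch_on_int(past_seen_tokens: int, x: torch.Tensor) -> torch.Tensor:
    if past_seen_tokens == 0:       # plain int: resolved while tracing
        return x
    return x + past_seen_tokens

def branch_on_tensor(cache_position: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    if cache_position[0] == 0:      # tensor element: data-dependent branch
        return x
    return x + cache_position[0]

x = torch.ones(4)
print(torch.compile(branch_on_int, fullgraph=True)(0, x))       # compiles and runs
# torch.compile(branch_on_tensor, fullgraph=True)(torch.arange(4), x)
# ^ would typically fail under fullgraph=True because of the data-dependent `if`
```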
@zucchini-nlp happy with the changes, feel free to merge! (given that you mentioned that you re-ran the tests 💛)
Yes, was exactly thinking to rebase on main and re-run the tests one more time
Tests are passing, including slow ones. So, merging
Can we update the tracker in #28981?
* squash into one commit
* add qwen2-vl for rope standardization
* fix mistral compile
* fix qwen2-vl
* fix-copies
What does this PR do?
Recently we merged a few PRs deprecating the old-style cache in all decoder-only models. This PR is a continuation of that work: here we verify that all of those models support static cache and are compatible with torch.compile. The main change is in RoPE, to get rid of dynamic control flow.
A few exceptions cannot be supported yet: MoE models and some others with dynamic control flow, like Phi3 or Chameleon.
Ran test_generate_compile_fullgraph and test_static_cache_matches_dynamic on all models + ran slow tests on models touched by this PR. In the next PR I can start deprecating the old cache in encoder-decoder models, starting from Bart and GPT models
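For anyone landing here later, a minimal sketch of what this compatibility enables (model name and generation settings are just examples, not prescribed by the PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # any decoder-only model covered by this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# The static cache keeps all tensor shapes fixed, so the forward pass can be
# compiled as a single graph.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```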