PhiMoE #33363
@ArthurZucker @gante can I please get a review?
…-amit/transformers into gargamit/onboard_phi3_5_moe
… gargamit/onboard_phi3_5_moe
Hi, it seems to be a very important and awaited PR! :) Other frameworks are willing to integrate MoE too, like in litgpt Lightning-AI/litgpt#1686.
We are very much willing to integrate it as well 🤗 Just came back from the torch conf, was a bit OO because of it 😢
Thanks!
Let's go with camel-cased classes. If we want to be compile-compatible, we need a conversion script and the formulation from the gpt-fast MoE, with a version implemented here: https://github.com/huggingface/transformers/pull/30793/files#diff-733ab0a772c69f78b1d8ed361e6ae1fda7243652887aed0bab5d3ecf07794c01R789
Lots of stuff seems similar to phi3, so we can probably copy from it!
TL;DR: overall, the mixer needs to be properly documented and written to be more understandable!
… gargamit/onboard_phi3_5_moe
@ArthurZucker Thanks for reviewing the PR. I’ve refactored the code according to your suggestions, and it’s ready for another look. Also, the failing test case appears to be unrelated to this PR. Please let me know if it needs to be addressed.
Reviewing! 🤗
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM, the only things needed to merge:
- The `Copied from` comments need a capital letter.
- The core part needs a tad more documentation, as I said: why do we need a specific gradient computation? (I had to go through the paper to see that you do indeed need a special gradient approximation.)
- That part of the code is IMO less readable than the rest, but fine for now!
Thanks, and sorry for the late review!
You no longer need this complicated structure! See the `__init__` for Albert, for example!
You need to define an `__all__` in the modeling and config files, and that's it.
    return torch.cat((-x2, x1), dim=-1)


# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
Suggested change:
- # copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
+ # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
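Why the capital letter matters can be sketched as follows. This assumes, for illustration only, that a consistency checker recognises the marker through a case-sensitive match on the literal prefix `# Copied from transformers.`; the pattern and the helper `copied_from_target` are hypothetical and not taken from the repository.

```python
import re

# Assumption: the marker is matched with this exact, case-sensitive prefix.
COPIED_FROM = re.compile(r"^\s*# Copied from transformers\.(\S+)")


def copied_from_target(line: str):
    """Return the dotted path after the marker, or None if the line is not a marker."""
    match = COPIED_FROM.match(line)
    return match.group(1) if match else None


lowercase = "# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb"
capitalized = "# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb"

print(copied_from_target(lowercase))    # None
print(copied_from_target(capitalized))  # models.mistral.modeling_mistral.apply_rotary_pos_emb
```

Under that assumption, a lowercase marker is silently treated as an ordinary comment, so the copy would never be checked against its source.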
self.rotary_emb = PhimoeRotaryEmbedding(
    config=self.config,
)
IMO you can already put this outside the Attention layer, and remove the `Copied from` Mixtral marker so that the position embeddings can be passed in!
Sure, moved it to the `PhimoeModel` class.
    return attn_output, attn_weights, past_key_value


# copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2 with Mixtral->Phimoe
Suggested change:
- # copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2 with Mixtral->Phimoe
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralFlashAttention2 with Mixtral->Phimoe
}


# copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe
Suggested change:
- # copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralBlockSparseTop2MLP with Mixtral->Phimoe
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Multiplier and selected experts tensors.
    """
    assert top_k == 2
Also, let's raise an error rather than use an assert!
Fixed
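A minimal sketch of what replacing the assert with an explicit error could look like. The helper name `check_top_k` and the error message are hypothetical; the actual fix in the PR may be worded differently.

```python
def check_top_k(top_k: int) -> None:
    """Validate the routing fan-out (hypothetical helper for illustration)."""
    # Unlike `assert top_k == 2`, this check still runs under `python -O`,
    # where assert statements are stripped, and it tells the caller what went wrong.
    if top_k != 2:
        raise ValueError(f"Only top_k=2 is supported, got top_k={top_k}.")


check_top_k(2)  # passes silently

try:
    check_top_k(3)
except ValueError as err:
    print(err)  # Only top_k=2 is supported, got top_k=3.
```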
routing_weights, selected_experts = sparsemixer(
    router_logits,
    top_k=2,
If it's hardcoded, we can also just leave it out!
@ArthurZucker I’ve removed `top_k` from here and instead created it as a keyword argument.
    config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
)

# copied from transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward
Suggested change:
- # copied from transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward
+ # Copied from transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward
@ArthurZucker Thanks for reviewing! I've addressed the comments and moved |
Great work! Thanks for integrating this new model 🔥
kv_seq_len = hidden_states.shape[-2]
if past_key_values is not None:
    kv_seq_len += past_key_values.get_usable_length(kv_seq_len)
position_embeddings = self.rotary_emb(hidden_states, seq_len=kv_seq_len)
Pretty sure you should be using the cache positions here: `cache_position[0]`!
it's the last nit!
Thanks for the suggestion! I've updated it to `cache_position[-1]+1`, as `cache_position[0]` would return 0 when the kv cache is empty.
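The difference between the two expressions can be sketched with plain Python lists standing in for the `cache_position` tensor. This assumes that `cache_position` holds the absolute positions of the tokens processed in the current forward pass; the helper name `seq_len_from_cache_position` is hypothetical.

```python
def seq_len_from_cache_position(cache_position):
    """Total sequence length seen so far: last absolute position plus one."""
    return cache_position[-1] + 1


# Prefill: the kv cache is empty and 5 prompt tokens are processed at once.
prefill = list(range(0, 5))
# Decode step: 5 tokens are already cached and 1 new token is processed.
decode = list(range(5, 6))

print(prefill[0], seq_len_from_cache_position(prefill))  # 0 5
print(decode[0], seq_len_from_cache_position(decode))    # 5 6
```

In the prefill case `cache_position[0]` is 0 even though five positions need rotary embeddings, whereas `cache_position[-1] + 1` gives the full length in both cases.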
indeed! 🤗
Thanks everyone and @garg-amit for bearing with me! Congrats on the model release again 🤗
* onboard phimoe model
* removed debug code
* added unit tests
* updated docs
* formatted
* fixed unit tests
* fixed test case
* fixed format
* refactored code
* fixed expected outputs in the integration tests
* Added a warning msg
* Addressed comments
* Addressed comments
* fixed test cases
* added paper link
* Addressed comments
* Refactored PhimoeForCausalLM forward fn
* Refactored PhimoeRotaryEmbedding class
* fixed test cases
* fixed testcase
* fixed test case
* Addressed comments
* fixed test cases
* fixed testcases
* Used cache position instead to get the seq len
What does this PR do?
Integrates PhiMoE into transformers. https://huggingface.co/microsoft/Phi-3.5-MoE-instruct
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@ArthurZucker @gante