Mixtral past_key_values and output_router_logits incompatible #30731
Comments
Hey! The `generate` function is not meant for training. That is why we don't test `past_key_values` together with `output_router_logits`. That said, the two are not really incompatible (you might want to look at the distribution of the router logits during generation).
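As the reply notes, router logits can be useful during generation, for example to see which experts each newly generated token is routed to. The following is a minimal plain-Python sketch under stated assumptions: each step's router logits for one layer have already been converted into a list of rows (one row of `num_experts` floats per token), routing is top-2 as in Mixtral, and `expert_selection_fractions` is a hypothetical helper, not part of transformers.

```python
def expert_selection_fractions(router_logits, top_k=2):
    """Fraction of top-k routing slots that each expert received.

    router_logits: list of rows, one row of num_experts floats per token
    (a plain-Python stand-in for one layer's router logits, accumulated
    over generation steps; an assumption, not a transformers type).
    """
    num_experts = len(router_logits[0])
    counts = [0] * num_experts
    for row in router_logits:
        # Softmax is monotonic, so ranking raw logits picks the same
        # top-k experts as ranking routing probabilities.
        top = sorted(range(num_experts), key=lambda e: row[e], reverse=True)[:top_k]
        for e in top:
            counts[e] += 1
    total_slots = len(router_logits) * top_k
    return [c / total_slots for c in counts]


# Three generated tokens, four experts (toy numbers, not real Mixtral outputs).
step_logits = [
    [2.0, 0.5, 1.0, -1.0],
    [0.1, 3.0, 2.5, 0.0],
    [1.5, -0.5, 1.4, 0.3],
]
fractions = expert_selection_fractions(step_logits)
print(fractions)  # -> [2/6, 1/6, 3/6, 0/6]
```

Counting top-k selections rather than averaging probabilities shows which experts actually process each token; averaging the softmax instead would show how confident the router is.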
Hi, could I take this up? @ArthurZucker
Sure, feel free to open a PR!
I am having trouble figuring out how to go about the
As you can see here, I think one approach is to retain the attention score of that particular token itself (the one that we are adding), and get rid of the other scores in the sequence. This is because, if we do retain the other ones, then
This is the approach I came up with; I hope you can verify its correctness and point out any mistakes I have made. The PR is here.
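The approach described above, read here as keeping only the attention-mask entries for the newly added token(s) so the mask lines up with the gate logits, can be sketched in plain Python. This is a minimal sketch under stated assumptions: lists stand in for tensors, gate-logit rows are ordered batch-major (row = b * step_len + t), and `trim_mask_to_step` and `masked_balance_loss` are hypothetical names, not transformers functions. The balance term follows the Switch Transformer form (num_experts × Σ share of top-k slots × mean routing probability); its normalization may differ from the one used in `load_balancing_loss_func`.

```python
import math


def trim_mask_to_step(attention_mask, gate_rows, batch_size):
    """Keep only the trailing mask columns covered by one layer's gate logits.

    attention_mask: batch_size rows of 0/1 over the WHOLE sequence so far.
    gate_rows: number of rows in one layer's gate logits (batch_size * step_len).
    """
    if gate_rows <= 0 or gate_rows % batch_size != 0:
        raise ValueError("gate_rows must be a positive multiple of batch_size")
    step_len = gate_rows // batch_size  # 1 on a cached decoding step
    return [row[-step_len:] for row in attention_mask]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [x / total for x in exps]


def masked_balance_loss(gate_logits, attention_mask, top_k=2):
    """Simplified Switch-style balance term for ONE layer:
    num_experts * sum_e f_e * P_e, where f_e is expert e's share of top-k slots
    and P_e its mean routing probability, both over unmasked tokens only.
    """
    batch_size = len(attention_mask)
    mask = trim_mask_to_step(attention_mask, len(gate_logits), batch_size)
    flat_mask = [m for row in mask for m in row]  # same batch-major order
    num_experts = len(gate_logits[0])
    n_tokens = sum(flat_mask)
    f = [0.0] * num_experts
    p = [0.0] * num_experts
    for logits, keep in zip(gate_logits, flat_mask):
        if not keep:
            continue  # padding token: excluded from both statistics
        probs = softmax(logits)
        top = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)[:top_k]
        for e in top:
            f[e] += 1 / (n_tokens * top_k)
        for e in range(num_experts):
            p[e] += probs[e] / n_tokens
    return num_experts * sum(fe * pe for fe, pe in zip(f, p))


# Cached decoding step: batch of 2, left-padded mask over 4 positions,
# but only one new token per sequence, so gate logits have 2 * 1 rows.
mask = [[1, 1, 1, 1], [0, 1, 1, 1]]
step_gate_logits = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
print(trim_mask_to_step(mask, len(step_gate_logits), batch_size=2))  # [[1], [1]]
print(masked_balance_loss(step_gate_logits, mask))  # 1.0
```

Because the step length is derived from the gate logits rather than from the mask, the same code handles both a full forward pass (gate rows = batch_size × sequence_length, nothing trimmed) and a cached step (gate rows = batch_size × 1).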
System Info
transformers==4.40.2
Python 3.11.8
Who can help?
@ArthurZucker
Reproduction
Expected behavior
It seems that this is the same underlying issue as in #29087: I would expect `past_key_values` to work with `output_router_logits`.
So what happens?
Without `past_key_values`, `all_router_logits` has the proper sequence length, so in `load_balancing_loss_func` the line `num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)` correctly evaluates the number of hidden layers.
With `past_key_values`, `all_router_logits` has a sequence length of 1, but the attention mask still covers the whole sequence (and `sequence_length` is inferred from it), so the number of hidden layers evaluates to a small value or 0. This leads to the same error as in "Mixtral inference breaks when `output_router_logits=True`" #29087.
Instead, I would like `load_balancing_loss_func` to be able to handle the case where the `gate_logits` passed are of shape `[batch_size X 1, num_experts]` instead of `[batch_size X sequence_length, num_experts]`.
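The layer-count arithmetic described above can be reproduced with plain integers. The values (32 layers as in Mixtral-8x7B, batch size 2, a 10-token prompt, and a 40-token prompt for the second case) are illustrative assumptions, and `inferred_num_layers` is a hypothetical helper that only mirrors the quoted division.

```python
def inferred_num_layers(total_gate_rows, batch_size, sequence_length):
    # Mirrors the quoted line:
    # concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
    return total_gate_rows // (batch_size * sequence_length)


NUM_LAYERS, BATCH_SIZE, PROMPT_LEN = 32, 2, 10  # illustrative values

# Without a cache, each layer's router logits cover every prompt token.
prefill_rows = NUM_LAYERS * BATCH_SIZE * PROMPT_LEN
prefill_layers = inferred_num_layers(prefill_rows, BATCH_SIZE, PROMPT_LEN)

# With past_key_values, each layer's router logits cover only the new token,
# but sequence_length comes from the attention mask (prompt length + 1 columns).
step_rows = NUM_LAYERS * BATCH_SIZE * 1
short_prompt_layers = inferred_num_layers(step_rows, BATCH_SIZE, PROMPT_LEN + 1)
long_prompt_layers = inferred_num_layers(step_rows, BATCH_SIZE, 40 + 1)

# Dividing by the number of tokens the logits actually cover gives the right count.
step_layers = inferred_num_layers(step_rows, BATCH_SIZE, 1)

print(prefill_layers, short_prompt_layers, long_prompt_layers, step_layers)  # 32 2 0 32
```

The short prompt yields a wrong but nonzero count (64 // 22 = 2), while the longer one drops to 0, matching the "small value or 0" described above. One possible direction is to take the per-step token count from the gate logits themselves (rows per layer divided by batch size) instead of from the attention mask.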