Fix load balancing loss func for mixtral #28256
Conversation
How does this differ from #28115?
Thanks a lot for deep diving. As discussed here, this is welcome; we indeed had a bug in the implementation. Let's try to help with shapes and use something explicit like this:
```python
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)  # [batch_size * sequence_length, top_k]
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)  # [batch_size * sequence_length, top_k, num_experts]
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)  # [top_k, num_experts]
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)  # [num_experts]
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) / top_k
return overall_loss * num_experts
```
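For context, the suggested computation can be exercised end to end. The following is a self-contained sketch of the idea; the function name and tensor sizes are illustrative, not the exact `modeling_mixtral.py` code:

```python
import torch

def load_balancing_loss(routing_weights, top_k, num_experts):
    # routing_weights: [batch_size * sequence_length, num_experts], rows sum to 1
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)  # [tokens, top_k]
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)  # [tokens, top_k, num_experts]
    # Fraction of tokens dispatched to each expert, per top-k slot
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)  # [top_k, num_experts]
    # Average router probability assigned to each expert
    router_prob_per_expert = torch.mean(routing_weights, dim=0)  # [num_experts]
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) / top_k
    return overall_loss * num_experts

routing_weights = torch.softmax(torch.randn(64, 8), dim=-1)
loss = load_balancing_loss(routing_weights, top_k=2, num_experts=8)
print(loss.dim(), loss.item())
```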
```
@@ -107,15 +107,16 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor
    selected_experts = selected_experts.reshape(-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    expert_mask = expert_mask.reshape(-1, top_k, num_experts)
```
This is not needed if we just remove `selected_experts = selected_experts.reshape(-1)`.
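The reviewer's point is easy to verify: `one_hot` accepts tensors of any shape and appends the class dimension last, so the flatten-then-reshape round trip is redundant. A small check (sizes below are illustrative):

```python
import torch

N, top_k, num_experts = 16, 2, 8
selected_experts = torch.randint(0, num_experts, (N, top_k))

# Flatten, one-hot, then reshape back ...
mask_flat = torch.nn.functional.one_hot(selected_experts.reshape(-1), num_experts)
mask_flat = mask_flat.reshape(-1, top_k, num_experts)
# ... produces exactly the same tensor as one-hot on the 2D tensor directly
mask_direct = torch.nn.functional.one_hot(selected_experts, num_experts)
print(torch.equal(mask_flat, mask_direct))  # True
```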
```
@@ -107,15 +107,16 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor
    selected_experts = selected_experts.reshape(-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    expert_mask = expert_mask.reshape(-1, top_k, num_experts)
    expert_mask = torch.max(expert_mask, dim=-2).values
```
`expert_mask = torch.max(expert_mask, dim=-2).values` should be removed.
What is the impact of this issue on Mixtral training? Will this fix conceivably improve the quality of training? Is it likely that previous Mixtral trainings are not as good as they could be? It seems like an important issue for those working with Mixtral, and it has been waiting on merge approval for a while.
I personally find the loss to be much lower with the new implementation, but I wasn't sure if that has to do with the `num_experts**2` scaling instead of just N. I'm pretty sure this is an error on the original Mixtral side. I'm still waiting for the training run with the newly implemented balancing loss to finish. DeepSpeed also has a top-2 implementation which we might be able to reference.
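One sanity check on the corrected formulation (a sketch reusing the suggestion above, not the exact library code): with perfectly uniform router probabilities, `router_prob_per_expert` is `1/num_experts` everywhere and `tokens_per_expert` sums to `top_k` over its entries, so the loss evaluates to exactly 1 regardless of which experts win the tied top-k:

```python
import torch

def load_balancing_loss(routing_weights, top_k, num_experts):
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
    router_prob_per_expert = torch.mean(routing_weights, dim=0)
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) / top_k
    return overall_loss * num_experts

num_experts, top_k = 8, 2
# Perfectly uniform router: every expert gets probability 1/num_experts
uniform = torch.full((128, num_experts), 1.0 / num_experts)
balanced_loss = load_balancing_loss(uniform, top_k, num_experts)
print(balanced_loss.item())  # 1.0
```

This gives a concrete baseline: the balanced minimum of the loss is 1, and imbalanced routing pushes it above 1.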
#28255 has information that could help. I am down to merge this for the release planned this week; just the comments need to be addressed. cc @liangxuZhang, do you need help to finish this?
@ArthurZucker LGTM. The new implementation is correct and concise, and I've made a new commit. In #28255, maybe we can have a deeper discussion about whether to concatenate the gate logits of all layers.
Alright! Pretty sure the math shows it's equivalent to compute on individual layers and then sum versus concatenate and compute, but let's merge this for now!
Thanks! The failing test seems unrelated; let's just rebase on main.
Force-pushed from cdd6509 to 0fe0244.
@ArthurZucker I've just rebased on the main branch, but I'm not sure if I did it right. Please tell me what else I need to do.
@liangxuZhang @ArthurZucker opinions about #28403? It looks complementary to this PR.
Something like
Thanks a lot @liangxuZhang for this fix! 🤗
great!
Impressive work!
* Correct the implementation of auxiliary loss of Mixtral
* Correct the implementation of auxiliary loss of Mixtral
* Implement a simpler calculation method

Co-authored-by: zhangliangxu3 <zhangliangxu3@jd.com>
What does this PR do?
Fixes #28255
Who can review?
@ArthurZucker and @younesbelkada
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.