
Fix load balancing loss func for mixtral #28256

Conversation

liangxuZhang
Contributor

What does this PR do?

Fixes #28255

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker and @younesbelkada
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

kalomaze commented Dec 30, 2023

How does this differ from #28115 ?

ArthurZucker (Collaborator) left a comment

Thanks a lot for the deep dive. As discussed here, this is welcome; we indeed had a bug in the implementation.

Let's help with the shapes and use something explicit like this:

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) # [batch_size X sequence_length, top_k]

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) # [batch_size X sequence_length, top_k, num_experts]

    tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # [top_k, num_experts]

    # Compute the average probability of routing to these experts
    router_prob_per_expert = torch.mean(routing_weights, dim=0) # [num_experts]

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) / top_k
    return overall_loss * num_experts
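For context, here is a minimal, self-contained sketch of how this suggestion could slot into the full auxiliary-loss function. The function name, signature, and per-layer concatenation below are assumptions for illustration, not the exact merged transformers code:

    import torch

    def load_balancing_loss(gate_logits, num_experts: int, top_k: int = 2) -> torch.Tensor:
        # gate_logits: tuple of per-layer router logits, each of shape
        # [batch_size * sequence_length, num_experts] (hypothetical signature).
        concatenated = torch.cat(list(gate_logits), dim=0)

        # Routing probabilities per token, [num_tokens, num_experts].
        routing_weights = torch.nn.functional.softmax(concatenated, dim=-1)

        # Indices of the top_k experts chosen for each token, [num_tokens, top_k].
        _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

        # One-hot mask of the selected experts, [num_tokens, top_k, num_experts].
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

        # Fraction of tokens dispatched to each expert, [top_k, num_experts].
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Average router probability assigned to each expert, [num_experts].
        router_prob_per_expert = torch.mean(routing_weights, dim=0)

        overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) / top_k
        return overall_loss * num_experts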

@@ -107,15 +107,16 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor
selected_experts = selected_experts.reshape(-1)

expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
expert_mask = expert_mask.reshape(-1,top_k, num_experts)
ArthurZucker (Collaborator):

this is not needed if we just remove selected_experts = selected_experts.reshape(-1)

@@ -107,15 +107,16 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor
selected_experts = selected_experts.reshape(-1)

expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
expert_mask = expert_mask.reshape(-1,top_k, num_experts)
expert_mask = torch.max(expert_mask, dim=-2).values
ArthurZucker (Collaborator):

Suggested change: the line

expert_mask = torch.max(expert_mask, dim=-2).values

should be removed.

codybum commented Jan 7, 2024

What is the impact of this issue on Mixtral training? Will this fix conceivably improve the quality of training? Is it likely that previous Mixtral training runs are not as good as they could have been?

It seems like an important issue for those working with Mixtral that has been waiting on merge approval for a while.

theblackcat102 (Contributor) commented Jan 8, 2024

I personally find the loss to be much lower with the new implementation, but I wasn't sure whether that has to do with the num_experts**2 scaling instead of just N. I'm pretty sure this is an error on the original Mixtral side. So far I'm still waiting for the training run with the newly implemented balance loss to finish. DeepSpeed also has an implementation of top-2 routing which we might be able to reference.
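For reference, the Switch Transformers-style auxiliary loss that this scaling discussion refers to is roughly the following. This is a sketch for the top-1 routing case; the function name and the alpha coefficient are illustrative, not the transformers API:

    import torch

    def switch_aux_loss(router_probs: torch.Tensor, expert_index: torch.Tensor,
                        num_experts: int, alpha: float = 1e-2) -> torch.Tensor:
        # router_probs: [num_tokens, num_experts] softmax outputs of the router.
        # expert_index: [num_tokens] expert chosen for each token (top-1 for simplicity).
        # f_i: fraction of tokens dispatched to expert i.
        f = torch.nn.functional.one_hot(expert_index, num_experts).float().mean(dim=0)
        # P_i: mean router probability assigned to expert i.
        p = router_probs.mean(dim=0)
        # The paper scales by num_experts (N), not N**2: with perfectly uniform
        # routing, f_i = P_i = 1/N, the sum is 1/N, and the loss reduces to alpha.
        return alpha * num_experts * torch.sum(f * p)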

ArthurZucker (Collaborator)

#28255 has information that could help. I am down to merge this for the release planned this week; just the comments need to be addressed. cc @liangxuZhang, do you need help finishing this?

liangxuZhang (Contributor, Author)

> #28255 has information that could help. I am down to merge this for the release planned this week; just the comments need to be addressed. cc @liangxuZhang, do you need help finishing this?

@ArthurZucker LGTM. The new implementation is correct and concise, and I've made a new commit. In #28255, maybe we can have a deeper discussion about whether to concatenate the gate logits of all layers.

ArthurZucker (Collaborator)

Alright! Pretty sure the math shows that computing the loss on individual layers and then summing is equivalent to concatenating and computing it once, but let's merge this for now!
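One quick, self-contained way to sanity-check this equivalence claim is to compare the two variants on random logits, using the hypothetical load_balancing_loss sketch from earlier in this thread:

    import torch

    torch.manual_seed(0)
    num_layers, num_tokens, num_experts, top_k = 4, 16, 8, 2
    gate_logits = tuple(torch.randn(num_tokens, num_experts) for _ in range(num_layers))

    # Variant A: concatenate the router logits of all layers and compute the loss once.
    loss_concat = load_balancing_loss(gate_logits, num_experts, top_k)

    # Variant B: compute the loss per layer and average over layers.
    loss_per_layer = torch.stack(
        [load_balancing_loss((layer,), num_experts, top_k) for layer in gate_logits]
    ).mean()

    print(loss_concat.item(), loss_per_layer.item())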

ArthurZucker (Collaborator) left a comment

Thanks! The failing test seems unrelated; let's just rebase on main.

liangxuZhang force-pushed the fix_load_balancing_loss_func_for_mixtral branch from cdd6509 to 0fe0244 on January 10, 2024, 10:41
liangxuZhang (Contributor, Author)

> Thanks! The failing test seems unrelated; let's just rebase on main.

@ArthurZucker I've just rebased on the main branch, but I'm not sure if I did it right. Please tell me what else I need to do.

bratao commented Jan 10, 2024

@liangxuZhang @ArthurZucker any opinions on #28403? It looks complementary to this PR.

ArthurZucker (Collaborator)

Something like git pull upstream main if the remote is named upstream. The exotic CI failure was fixed on main, so I'll merge without it!

ArthurZucker merged commit e768616 into huggingface:main on January 11, 2024 (15 of 17 checks passed).
ArthurZucker (Collaborator)

Thanks a lot @liangxuZhang for this fix! 🤗

dancingpipi (Contributor)

great!

cryoco commented Jan 12, 2024

Impressive work!

staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
* Correct the implementation of auxiliary loss of mixtral

* correct the implementation of auxiliary loss of mixtral

* Implement a simpler calculation method

---------

Co-authored-by: zhangliangxu3 <zhangliangxu3@jd.com>
MadElf1337 pushed a commit to MadElf1337/transformers that referenced this pull request Jan 15, 2024
* Correct the implementation of auxiliary loss of mixtral

* correct the implementation of auxiliary loss of mixtral

* Implement a simpler calculation method

---------

Co-authored-by: zhangliangxu3 <zhangliangxu3@jd.com>
wgifford pushed a commit to wgifford/transformers that referenced this pull request Jan 21, 2024
* Correct the implementation of auxiliary loss of mixtral

* correct the implementation of auxiliary loss of mixtral

* Implement a simpler calculation method

---------

Co-authored-by: zhangliangxu3 <zhangliangxu3@jd.com>
AjayP13 pushed a commit to AjayP13/transformers that referenced this pull request Jan 22, 2024
* Correct the implementation of auxiliary loss of mixtral

* correct the implementation of auxiliary loss of mixtral

* Implement a simpler calculation method

---------

Co-authored-by: zhangliangxu3 <zhangliangxu3@jd.com>
Development

Successfully merging this pull request may close this issue: Incorrect implementation of auxiliary loss (#28255).
8 participants