Bug in distributed wrapper regarding cross batch memory loss #639
Comments
Thanks for the code and explanation @zhaoyuac09! I've found the distributed stuff to be quite tricky. I'm really busy for the next few days, so I'll have to look at your code a bit later. In the meantime, if you'd like, you can open a pull request with your code changes.
Thank you @KevinMusgrave. I would be happy to create a pull request once I finish more test cases here. If all the test cases pass, I will wrap up the changes and open a pull request.
I believe your repo is really nice and almost there for distributed training support. Thanks for the nice repo, and let's make it even better.
I am facing the same issue. @KevinMusgrave, have you reviewed @zhaoyuac09's PR?
@lolongcovas It's not passing the existing test. See my comment: #642 (comment)
First of all, I really appreciate this repo. Thank you very much for the contribution! However, there are 2 functions in distributed.py, in the loss and miner wrappers, that will not work correctly: gather_emb_and_ref and gather_enqueue_mask.
Let's take gather_enqueue_mask as an example:
The all_gather function pops the entry at the current rank (which is a different integer on each GPU) and then concatenates the current enqueue_mask onto the gathered tensors. As a result, the order of the all-gathered masks is not guaranteed to be the same across GPUs. When using a cross batch memory loss, the embedding_memory therefore ends up different on different GPUs, which I have already confirmed by running some test functions.
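For illustration, here is a minimal sketch of the pattern described above (not the library's exact code; it assumes torch.distributed is initialized and gather_enqueue_mask_buggy is just an illustrative name):

```python
import torch
import torch.distributed as dist

def gather_enqueue_mask_buggy(enqueue_mask):
    # Gather the mask from every rank (the list is ordered by rank).
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    gathered = [torch.zeros_like(enqueue_mask) for _ in range(world_size)]
    dist.all_gather(gathered, enqueue_mask.contiguous())

    # Pop this rank's copy and append the local mask at the end.
    # "The end" is a different position on each rank:
    #   rank 0 sees [m1, m2, ..., m0], rank 1 sees [m0, m2, ..., m1], etc.
    # so the concatenated result is ordered differently on each GPU.
    gathered.pop(rank)
    return torch.cat(gathered + [enqueue_mask], dim=0)
```

Because CrossBatchMemory enqueues these gathered tensors into its embedding_memory, the per-rank ordering difference makes the memory banks drift apart across GPUs.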
Here I propose 2 changes to fix this issue:
and
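As a rough, hypothetical sketch of the general idea behind such a fix (not the actual patch, and gather_in_rank_order is an illustrative name): gather into a list indexed by rank and concatenate in rank order, without popping the current rank, so every process builds an identical tensor.

```python
import torch
import torch.distributed as dist

def gather_in_rank_order(x):
    # Gather x from every rank; the list is ordered by rank on all processes.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x.contiguous())

    # all_gather does not propagate gradients through the received tensors,
    # so reinsert the local tensor at its own rank position to keep autograd
    # working for the local embeddings.
    gathered[dist.get_rank()] = x

    # Concatenating in rank order yields the same result on every GPU,
    # so CrossBatchMemory's embedding_memory stays in sync.
    return torch.cat(gathered, dim=0)
```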