MPI+jax=mpi4jax #452
Conversation
This prevents a nasty bug where non-rank-0 processes, since they don't log, never execute the buffer of async operations.
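For illustration only, here is a minimal sketch of the kind of fix this describes, assuming the issue is that only rank 0 ever forces execution of the asynchronously dispatched step (via logging); the function and names are hypothetical, not NetKet's actual driver code:

```python
import jax

# Hypothetical driver step; `step_fn`, `logger` and `rank` are illustrative.
def advance(state, step_fn, logger=None, rank=0):
    # Dispatch one (possibly asynchronous) optimization step.
    state = step_fn(state)

    # Force execution on *every* rank, not only on rank 0 where logging
    # happens; otherwise non-rank-0 processes may never flush their buffer
    # of queued async (MPI) operations.
    state = jax.block_until_ready(state)

    if rank == 0 and logger is not None:
        logger(state)
    return state
```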
Force-pushed from 731086f to 4f256c9
This PR is completed. I added some lines to the readme indicating that if one wants to use jax machines and SR with MPI, one must install mpi4jax.

Scaling with this is very good, with approximately a 75% increase in speed when doubling the number of processes at low process count.

For posterity's sake: what this PR does is essentially use the mpi4jax package to jit through MPI calls in jax-jitted code, so that SR works with jax machines across several nodes.

Lastly, this PR also enables jax omnistaging when mpi4jax is installed.
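For reference, this is roughly how an MPI collective is expressed inside jax-jitted code with mpi4jax (a minimal sketch assuming the documented mpi4jax API, where allreduce returns a result together with a token; not code from this PR):

```python
import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def global_mean(local_values):
    # mpi4jax registers the MPI call as a jax primitive, so the allreduce
    # can be traced and jitted together with the rest of the computation.
    total, _token = mpi4jax.allreduce(local_values, op=MPI.SUM, comm=comm)
    return total / comm.Get_size()

# Run with e.g. `mpirun -np 4 python example.py`; every rank gets the same mean.
print(rank, global_mean(jnp.full((3,), float(rank))))
```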
Amazing work @PhilipVinc, thanks a lot!!!
Use mpi4jax to jit through MPI calls in jax-jitted code.
This allows SR to work with jax machines in many node setups.
The dependency is optional: if jax is installed but only 1 process is used, nothing extra is required. If more than 1 process is used with jax machines, an error is thrown at SR construction instructing the user to pip install mpi4jax. This complicates the test matrix, however, as we should be running MPI tests both with and without mpi4jax, which I still have to do.
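A minimal sketch of such an optional-dependency check (the helper name and error message are hypothetical, not NetKet's actual implementation):

```python
from mpi4py import MPI

def _require_mpi4jax():
    """Raise an informative error when SR with jax machines runs on more
    than one MPI process without mpi4jax installed."""
    if MPI.COMM_WORLD.Get_size() <= 1:
        return  # single process: plain jax is enough, nothing extra required
    try:
        import mpi4jax  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "Using SR with jax machines on more than one MPI process "
            "requires the optional dependency mpi4jax. "
            "Install it with `pip install mpi4jax`."
        ) from err
```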
Missing from this PR are tests and maybe an informational message when someone first runs netket.