MPI+jax=mpi4jax #452

Merged
merged 28 commits into from
Sep 16, 2020

PhilipVinc
Member

Use mpi4jax to jit through MPI calls in jax-jitted code.
This allows SR (stochastic reconfiguration) to work with jax machines in multi-node setups.

The dependency is optional: if jax is installed but only one process is used, nothing else is required. If more than one process is used with jax machines, constructing SR raises an error instructing the user to pip install mpi4jax.
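A minimal sketch of how such an optional-dependency check could look. This is not the code in this PR: the helper name `check_mpi4jax`, its `find_spec` parameter, and the error wording are assumptions made for illustration.

```python
import importlib.util


def check_mpi4jax(n_processes, find_spec=importlib.util.find_spec):
    """Raise if several MPI processes are used but mpi4jax is missing.

    `check_mpi4jax` is a hypothetical helper, not NetKet's actual API.
    `find_spec` is injectable so the check can be exercised without
    installing or uninstalling anything.
    """
    if n_processes <= 1:
        # Single process: no MPI reduction happens, nothing is required.
        return
    if find_spec("mpi4jax") is None:
        raise ImportError(
            "Running jax machines with SR on more than one MPI process "
            "requires mpi4jax. Install it with `pip install mpi4jax`."
        )
```

Under these assumptions, calling the check once when SR is constructed keeps single-process runs free of the extra dependency while failing early in multi-process runs.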

This complicates the test matrix, however, since we should run the MPI tests both with and without mpi4jax, which I still have to do.
Still missing from this PR are tests and perhaps an informational message when someone first runs netket.

@PhilipVinc PhilipVinc force-pushed the PhilipVinc/mmagimmagia branch from 731086f to 4f256c9 Compare September 16, 2020 09:42
@PhilipVinc
Member Author

This PR is completed.

I added some lines to the README indicating that anyone who wants to use jax machines and SR with MPI must install mpi4jax.

Scaling is very good: speed increases by approximately 75% when doubling the number of processes at low process counts.

For posterity: this PR essentially uses the Allreduce function from mpi4jax, which can be jitted, differentiated (AD), and batched, when performing SR.
This has a complication: because jax executes operations lazily, they are not performed until someone prints the contents of a jax array, converts it to a numpy array, or flushes the operation cache. As we only print weights on rank 0, the other ranks may not run their pending MPI operations, which can lead to deadlocks.
For this reason, when running with mpi4jax, the queue of pending operations on the weights is flushed after each optimiser step (that is, every time the weights are updated).
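The flush described above can be sketched roughly as follows. This is an illustration, not the code in this PR: `flush_pending` is a hypothetical helper, and it assumes the weights are stored in nested dicts, lists, or tuples whose leaves are jax arrays, which expose a `block_until_ready()` method.

```python
def flush_pending(params):
    """Wait until every array leaf in a nested container is computed.

    Hypothetical helper: leaves are detected by duck typing, so any
    object with a `block_until_ready()` method (as jax arrays have)
    is forced; other leaves are left untouched.
    """
    if isinstance(params, dict):
        for value in params.values():
            flush_pending(value)
    elif isinstance(params, (list, tuple)):
        for value in params:
            flush_pending(value)
    elif hasattr(params, "block_until_ready"):
        # Forces the pending computation (including any MPI calls it
        # depends on) to actually run on this rank.
        params.block_until_ready()
    return params
```

Calling such a helper on every rank right after the optimiser step means all ranks execute the same MPI operations, whether or not they later print the weights.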

Lastly, this PR also enables jax omnistaging when mpi4jax is installed.
Omnistaging is a very new feature of jax (see linked issue for details), and jax will soon enable it by default.
mpi4jax needs it to ensure that jax never ignores MPI operations (jax might otherwise decide that they are not needed).
Omnistaging significantly changes the code generated by jax. However, in the little benchmarking I did, it usually improved performance for us, so it should not be detrimental.

@gcarleo gcarleo self-requested a review September 16, 2020 16:18
Member

@gcarleo gcarleo left a comment

Amazing work @PhilipVinc, thanks a lot!!!

@gcarleo gcarleo merged commit 9dc763b into v3.0 Sep 16, 2020
@gcarleo gcarleo deleted the PhilipVinc/mmagimmagia branch September 16, 2020 16:19
2 participants