Update JAX Quickstart with parallelization (pmap), other changes #3520

Closed
wants to merge 4 commits into from

8bitmp3
Contributor

@8bitmp3 8bitmp3 commented Jun 22, 2020

PR ref: #3458

  • Added a section covering pmap (using accelerators in parallel) with material from @skye's NeurIPS 2019 demo, the main JAX GitHub README, and the pmap API docs
  • Added annotation to code examples to help new users
  • Added that JAX is a system for general numerical computing, since it's not just for machine learning research (ref: NeurIPS 2019 demo)
  • Expanded the Hessian example for the vectorized map (vmap) - please check if it's fine
  • Added more info on the fwd/reverse-mode Jacobian matrices, using The Autodiff Cookbook as a reference (also fixed two small links in that notebook as a result)
  • Changed importing JAX NumPy to as jnp instead of np in line with other examples (we have ordinary NumPy as onp, so it may be confusing otherwise for new users)
  • Mentioned, just in case, that vmap is a lot like a regular Python map (but...)
  • Added links to source code/APIs (e.g. device_put, vjp, @jit, etc.)
  • Expanded the intro and some other paragraphs; included a small "table of contents"
  • Added that "JAX is fast, easy to use, and uses a functional programming model..."
  • Added more info about JIT (almost no handwritten kernels, you get almost "eager-mode" performance)
  • Added a sentence about how user-friendly the API is ("pass a function into a transform and get a function out...")
  • Made other minor changes using Google's doc style guide (addressing the user as "you" instead of "we", and other stuff)

Most of these are just UX improvement suggestions. Feel free to rip them apart/propose new changes/propose to keep the old stuff.

cc @mattjj @shoyer

Thanks for your work 👍

@skye
Member

skye commented Jun 22, 2020

Blah, this breaks RTD because it can't run the notebook without a Cloud TPU :( I think this means we have to choose between pmap examples (requires multiple devices --> Cloud TPU) and automated testing/generation of the Quickstart notebook via RTD. I personally think the pmap examples are worth it (how often do we break the quickstart notebook?), but I could be convinced otherwise. @gnecula @shoyer do you have any alternate ideas, or opinions on what we should do here?

(I was wondering if we could add conditional guards so RTD only runs what it can, but I think this means we wouldn't have pmap output at https://jax.readthedocs.io/en/latest/notebooks/quickstart.html.)
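For illustration, a conditional guard of the kind mentioned above might look like the following. This is a minimal sketch, not code from the PR: the helper name `should_run_pmap_cells` and the device count of 8 are assumptions, and the guard falls back to "0 devices" when JAX is not importable.

```python
def should_run_pmap_cells(required_devices: int, available_devices: int) -> bool:
    """Return True when enough local devices exist to run the pmap examples."""
    return available_devices >= required_devices


# Count local devices; treat a missing JAX install as "no devices" (assumption).
try:
    import jax
    available = jax.local_device_count()
except ImportError:
    available = 0

if should_run_pmap_cells(8, available):
    import jax.numpy as jnp
    # One array element per device along the leading axis.
    print(jax.pmap(lambda x: x ** 2)(jnp.arange(8)))
else:
    print(f"Skipping pmap examples: 8 devices needed, {available} available.")
```

On a CPU-only builder such as RTD the guard takes the skip branch, which is exactly the drawback noted above: the rendered page would show no pmap output.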

@8bitmp3
Contributor Author

8bitmp3 commented Jun 22, 2020

@skye just a thought:

Would adding the quickstart to exclude_patterns in https://github.com/google/jax/blob/master/docs/conf.py make sense? These notebooks have long computations, apparently:

exclude_patterns = [
    # Slow notebook: long time to load tf.ds
    'notebooks/neural_network_with_tfds_data.ipynb',
    # Slow notebook
    'notebooks/Neural_Network_and_Data_Loading.ipynb',
    'notebooks/score_matching.ipynb',
    'notebooks/maml.ipynb',
    # Fails with shape error in XL
    'notebooks/XLA_in_Python.ipynb',
    # Sometimes sphinx reads its own outputs as inputs!
    'build/html',
] 

Also, to keep the original link/reference on the ReadTheDocs site, you could manually edit https://github.com/google/jax/blob/master/docs/index.rst or something like that? I'm not too familiar with this process yet.

image

Are there any plans to add any pmap notebooks in the future?

@skye
Member

skye commented Jun 22, 2020

My understanding is that adding the quickstart to exclude_patterns means we don't actually run that notebook as part of our documentation build (including when built by RTD), which is what I was suggesting in my previous comment. So yes, we could still link to the notebook from the docs and have a rendered version of it outside of colab, but it wouldn't be dynamically generated by running the notebook anymore. The downside is we won't automatically trigger an error if we make code changes that break the notebook somehow, or even if we don't introduce an error, the saved notebook output could become out-of-date until we manually regenerate it.

I don't think there are any immediate plans for a pmap notebook beyond the pmap cookbook at https://github.com/google/jax/tree/master/cloud_tpu_colabs.

@8bitmp3
Contributor Author

8bitmp3 commented Jun 22, 2020

@skye I can set the default Colab notebook accelerator to GPU and move the extra TPU settings cell down to the vmap section (you can optionally run it). That way, most of the notebook shouldn’t break the RTD, right? 🤔

@shoyer
Collaborator

shoyer commented Jun 23, 2020

RTD can only execute using a CPU. So we can't use pmap on it.

@8bitmp3
Contributor Author

8bitmp3 commented Jun 23, 2020

@shoyer maybe I can switch the default hardware in the Colab notebook to CPU while keeping the output. Then commit that new ipynb here. We can add a note that says if you want to run pmap, click Open in Colab at the top of the page. WDYT

@8bitmp3
Contributor Author

8bitmp3 commented Jun 23, 2020

With import ☕️, I got the build to work.

Given the lack of pmap support in RTD, is this (temp) solution alright with the team @skye @shoyer @mattjj ?

  • Switched the default hardware to CPU in Colab.
  • Added instructions on what to do to run pmap examples (click on run in Colab, etc.).
  • Converted the TPU driver config cell and the rest of the pmap code to markdown with example outputs.
    • (I looked at the logs of failed builds and noticed that RTD trips when attempting to run unfamiliar cells and converting them to HTML.)

Read the Docs build should recognize this as simple markdown. It's a temp workaround.

(Further down the line: I can investigate how to run Notebooks in other OS doc building solutions, such as Hugo with Google's Docsy theme - used by Kubeflow and Kubernetes projects. Apparently, their build speed is good (the site is static too), and Hugo may have notebook-integration.)

@shoyer
Collaborator

shoyer commented Jun 23, 2020

@8bitmp3 thanks for your work on this!

I'll let @skye take the lead on reviewing. But I did want to note that you might want to take a look at the rendered version of the notebook by following the "docs/readthedocs.org:jax" link under "Checks" : https://jax--3520.org.readthedocs.build/en/3520/notebooks/quickstart.html

One thing you'll notice is that unfortunately formatted links don't work. This appears to be a limitation of reStructuredText: spatialaudio/nbsphinx#301

@8bitmp3
Contributor Author

8bitmp3 commented Jun 23, 2020

@skye Check out the new commit that fixes external link rendering in RTD/Sphinx (see details below). Let me know what you think about this and other workarounds. I've built the Docs site locally and all looks OK so far.

image

Thanks @shoyer for spotting this link issue in Sphinx/reStructuredText. Another JAX notebook here had a similar issue that I spotted recently and fixed by attaching the hyperlinks to plain words (see 86fcc77 and f351da7)

In the Quickstart, I use the following format for functions and methods that have API doc links:

[pmap()](https://jax.readthedocs.io/...)

which renders as plain text with round brackets, e.g. ohFunctionMyFunction(). Notice the lack of back ticks: they caused the original issue (see below).

On the other hand, they're blue and not red.

More background:

For external links, Sphinx follows this rule, which already uses back ticks:

This is the preferred way in RTD/Sphinx:

`pmap <https://jax.readthedocs.io/...>`_

Meanwhile, this clashes with standard Markdown, which Sphinx supports, just not fully:

This is normal Markdown outside of RTD/Sphinx:

[`pmap`](https://jax.readthedocs.io/...)

(See: https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html#external-links)

So, adding more back ticks around the name of functions that we want to link the API docs to only breaks the sequence/links/everything else.

For everything else that doesn't require external links, we leave them as before: `jax.random`, `vmap`, `@jit` (vs @jit), etc.
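The rewrite described above (dropping the back ticks inside Markdown link text) can be sketched as a small text transform. This is only an illustration under stated assumptions: the function name `strip_backticks_from_links` and the regular expression are hypothetical and are not part of the PR; back ticks outside of links are left alone.

```python
import re

# Matches [`name`](url): a Markdown link whose text is a single code span.
_BACKTICKED_LINK = re.compile(r"\[`([^`\]]+)`\]\(([^)\s]+)\)")


def strip_backticks_from_links(markdown: str) -> str:
    """Rewrite [`name`](url) as [name](url); leave everything else untouched."""
    return _BACKTICKED_LINK.sub(r"[\1](\2)", markdown)


before = "Use [`pmap`](https://jax.readthedocs.io/...) together with `vmap`."
print(strip_backticks_from_links(before))
```

Applied to the markdown cells of a notebook, this keeps plain code spans such as `vmap` as they are and only changes linked function names.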

@skye
Member

skye commented Jun 24, 2020

I like the explicit instructions on how to enable the TPU code! It has the added benefit of teaching new users how to use Cloud TPU colabs :) However, I still think it'd be better to manually generate the notebook so we can have actual code cells + output in the pmap section, instead of requiring users to copy/paste. There's nothing here that's expected to change, so I think the added usability is worth the slight risk of a stale notebook.

So specifically, add quickstart.ipynb to exclude_patterns in conf.py, and then save a version of the notebook with outputs (looks like you already did this). I think that should be it...
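As a sketch of that change, the exclude_patterns list quoted earlier in this thread would gain one entry. The path 'notebooks/quickstart.ipynb' is inferred from the other entries and from the notebook's location under docs/notebooks; the comment wording is an assumption.

```python
# Hypothetical excerpt of docs/conf.py after the proposed change.
exclude_patterns = [
    # Slow notebook: long time to load tf.ds
    'notebooks/neural_network_with_tfds_data.ipynb',
    # Slow notebook
    'notebooks/Neural_Network_and_Data_Loading.ipynb',
    'notebooks/score_matching.ipynb',
    'notebooks/maml.ipynb',
    # Fails with shape error in XL
    'notebooks/XLA_in_Python.ipynb',
    # pmap section needs multiple devices (Cloud TPU); RTD executes on CPU only
    'notebooks/quickstart.ipynb',
    # Sometimes sphinx reads its own outputs as inputs!
    'build/html',
]
```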

Also, even though we could switch to TPU at the beginning with this scheme, I like how you do it here and switch just above the pmap examples. That way readers won't be scared away by the preamble at the beginning, and we get the added educational benefit.

WDYT?

@8bitmp3
Contributor Author

8bitmp3 commented Jun 25, 2020

@skye That's awesome, thank you so much for the feedback 🙏

  • Move all the pmap code back to code cells
  • Add quickstart.ipynb to exclude_patterns = [ ...] in conf.py
  • NEW ISSUE (see below) Save the notebook as a notebook with outputs
    - Set the default hardware in the notebook: CPU or TPU? (See Solution 1 and 2 below)
  • NEW FEATURE: Add more NN lib references—Google Brain's Flax and DeepMind's Rlax—to ... For a neural network example ... or the much more capable [Trax library](https://github.com/google/trax/) - yay or nay, LMK

New issue: If you set the default hardware to TPU in the beginning and run @jit, @grad, and @pmap stuff, everything works as expected. After running these computations, setting the TPU driver as the backend creates a new issue: you get only one device. If you run a pmap on anything (e.g. lambda x: x ** 2) with the number of XLA devices set to 8, as before, via jnp.arange, you're told you only have one device available (The pmap Cookbook has an example with 7):

...
/usr/local/lib/python3.6/dist-packages/jax/interpreters/pxla.py in parallel_callable(fun, backend, axis_name, axis_size, global_axis_size, devices, name, mapped_invars, donated_invars, *avals)
    785              "devices are available (num_replicas={}, num_partitions={})")
    786       raise ValueError(msg.format(num_global_shards, xb.device_count(backend),
--> 787                                   num_global_replicas, num_partitions))
    788 
    789     # On a single host, we use the platform's default device assignment to

ValueError: compiling computation that requires 8 logical devices, but only 1 XLA devices are available (num_replicas=8, num_partitions=1)
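The error arises because pmap maps over the leading axis with one element per device, so an axis of size 8 cannot be placed on a single device. A minimal sketch, assuming JAX is installed: sizing the input from `jax.local_device_count()` keeps the same example runnable on whatever hardware is attached (the variable names here are illustrative, not taken from the notebook).

```python
import jax
import jax.numpy as jnp

# Size the mapped axis to the devices actually present (1 on a plain CPU host,
# 8 on a Cloud TPU v2/v3 Colab runtime).
n_devices = jax.local_device_count()
xs = jnp.arange(n_devices)

# Each device squares one element of xs.
squares = jax.pmap(lambda x: x ** 2)(xs)
print(n_devices, squares)
```

Hard-coding `jnp.arange(8)` instead reproduces the ValueError whenever fewer than 8 devices are visible to the backend.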

Solution 1: Run the notebooks with CPUs until you reach pmap, switch to TPUs, set the TPU backend, run pmap example cells 👍

Solution 2: If TPU is enabled from the beginning, run the TPU driver backend config cell first, perform the rest of the calculations thereafter 🤔

Solution 2 creates a less desirable UX by undoing this:

Also, even though we could switch to TPU at the beginning with this scheme, I like how you do it here and switch just above the pmap examples. That way readers won't be scared away by the preamble at the beginning, and we get the added educational benefit.

WDYT

Thank you!

(What if TPUv2 allocation is carried out with RL (Hierarchical Planning for Device Placement.))

@skye
Member

skye commented Jun 25, 2020

Let's leave out the NN library references. Unfortunately it's pretty hard to review notebook changes, so I'd like to keep this change as minimal as possible (I haven't actually looked at all the individual edits you've made yet, I'll do that once we get the overall structure nailed down). Plus the rest of the core team / NN library authors might have opinions about that :)

For the runtime issue, I agree Solution 1 is better (let's run on GPU until we switch to TPU though, so the microbenchmarks still look good). Also note you may have to import jax, etc. again after switching the runtime, since it resets your state.

(What if TPUv2 allocation is carried out with RL (Hierarchical Planning for Device Placement.))

😆

@8bitmp3
Contributor Author

8bitmp3 commented Jun 25, 2020

Done @skye.

New:

  • Set the default hardware in the notebook: GPU ("accelerator": "GPU" in JSON speak)
  • Import jax libs again when configuring the TPU driver after runtime switching in the pmap section
  • Add a reference to the notebook in docs/index.rst, since JAX Quickstart is not under Tutorials any more:
For an introduction to JAX, start at the `JAX GitHub page <https://github.com/google/jax>`_. Then, try it yourself with the `JAX Quickstart notebook in Google Colab <https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/quickstart.ipynb>`_.

  • Update READMEs to "2020" if that's 👍

As before:

  • Move all the pmap code back to code cells
  • Add quickstart.ipynb to exclude_patterns = [ ... ] in conf.py

Thanks @skye. It's ready for your review.

Member

@skye skye left a comment

Update READMEs to "2020"

We don't update copyrights actually, I guess they're supposed to reflect when the file was first created.

Add a reference to the notebook in docs/index.rst

Can you also put a link at the top of the Tutorials section immediately below? Basically how it is currently. I think it's good to have it listed there too for people like me who skim over prose :) And also to have a link to the HTML version.

The overall structure looks good, but I'm having trouble reviewing the exact changes to the notebook because there are so many, and it's hard to review notebooks. Would it be possible to separate this out into two changes, one containing the changes to the notebook above the new pmap section, and one with everything else (the new pmap section, disabling RTD, etc.)? Probably the pmap + Cloud TPU change should come first. If it's a big pain to split it up at this point I can review the whole thing (maybe with the help of ReviewNB), lemme know.

@skye
Member

skye commented Jun 26, 2020

Hmm I guess posting the above comment also published the in-file comments from when I started reviewing, before realizing this would take a long time. I deleted them to avoid confusion.

@8bitmp3
Contributor Author

8bitmp3 commented Jul 1, 2020

To make the review process easier @skye I’ll reverse the commits and then introduce separate ones for pmap, the rest of the notebook, and changes to .rst files.

@8bitmp3
Contributor Author

8bitmp3 commented Jul 4, 2020

@skye @shoyer I may have found a long-term solution to working with pmap notebooks (that need TPUs) in a ReadTheDocs-hosted environment. Instead of using the Sphinx site generator, you can use MkDocs and an MkDocs Jupyter plugin that doesn't require executing notebooks.

Blah, this breaks RTD because it can't run the notebook without a Cloud TPU :( I think this means we have to choose between pmap examples (requires multiple devices --> Cloud TPU) and automated testing/generation of the Quickstart notebook via RTD.

RTD can only execute using a CPU. So we can't use pmap on it.

The current JAX site generator is Sphinx, which relies on reStructuredText. Like Sphinx, MkDocs is also recommended by ReadTheDocs. Unlike Sphinx though, it uses Markdown.

For Jupyter notebooks, there's a plugin that lets you not only render .ipynb files, but also view them without having to pre-run them (you simply set execute: False in a YAML file).

Below is a screenshot from a simple docs site I built. This is the pmap cookbook notebook, which runs on TPUs by default:

image

The docs site is stored in a /docs directory inside the repo, just like the current docs site. You configure plugins and extensions in a YAML file. There's a fast search bar and the theme I used is free (it's called Material). There're options to enable text and code highlights and all the other modern bells and whistles within that Material Theme. For example, I added a plugin to show the date of the last commit at the end of each page to easily identify outdated docs.

By the way, the search is fast—it shows the results as you type:

image

@8bitmp3
Contributor Author

8bitmp3 commented Jul 5, 2020

FYI, it's possible to use MkDocs instead of Sphinx and keep the ReadTheDocs theme (it comes pre-loaded and you set it as theme: name: readthedocs).

Below is another example with the blue color scheme in the material theme:

The current JAX site generator is Sphinx, which relies on reStructuredText. Like Sphinx, MkDocs is also recommended by ReadTheDocs. Unlike Sphinx though, it uses Markdown.

For Jupyter notebooks, there's a plugin that lets you not only render .ipynb files, but also view them without having to pre-run them (you simply set execute: False in a YAML file).

Below is a screenshot from a simple docs site I built.

This is the JAX+TFDS notebook that's not pointing to an external link like it currently does (➡️ https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb). You're simply viewing the notebook in http(s)://<jax.host.io>/notebooks/neural_network_with_tfds_data/, same as with the pmap notebook:

image

@8bitmp3
Contributor Author

8bitmp3 commented Jul 6, 2020

Another suggestion @skye @shoyer

  • What if we don't let the Sphinx plugin execute the notebooks when building the docs site?
  • Instead, we run the cells in Colab from top to bottom before saving the notebooks while keeping all the cell output.
  • Sphinx's Jupyter notebook plugin - nbsphinx - would render only the input and the output (which has been pre-run in Colab in advance).
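If notebooks are saved with pre-run output instead of being executed at build time, a small check could flag code cells that were never run before saving. The sketch below is an assumption-laden illustration: it relies only on the standard notebook JSON layout (a `cells` list whose code cells carry `execution_count`), and the function names `cells_missing_output` and `load_notebook` are hypothetical.

```python
import json


def cells_missing_output(notebook: dict) -> list:
    """Return indices of non-empty code cells that were never executed."""
    missing = []
    for index, cell in enumerate(notebook.get("cells", [])):
        if cell.get("cell_type") != "code":
            continue
        # "source" may be a string or a list of strings; join handles both.
        has_source = bool("".join(cell.get("source", [])).strip())
        if has_source and cell.get("execution_count") is None:
            missing.append(index)
    return missing


def load_notebook(path: str) -> dict:
    """Read an .ipynb file (plain JSON) from disk."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Tiny in-memory notebook: the last code cell was saved without being run.
example = {
    "cells": [
        {"cell_type": "markdown", "source": ["# Quickstart"]},
        {"cell_type": "code", "source": ["1 + 1"], "execution_count": 1,
         "outputs": [{"output_type": "execute_result"}]},
        {"cell_type": "code", "source": ["2 + 2"], "execution_count": None,
         "outputs": []},
    ]
}
print(cells_missing_output(example))
```

The check looks at `execution_count` and not at `outputs`, since a cell can run successfully without printing anything. It does not detect stale output, only cells that were skipped.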

I got this idea because in MkDocs (a Sphinx alternative) there's a plugin called MkDocs Jupyter that allows you to:

view them without having to pre-run them (you simply state execute: False in a YAML).

And in Sphinx - where you use the nbsphinx plugin for Jupyter - there's an option to never execute.

So, in conf.py you change the settings to never from always:

...
# Execute notebooks before conversion: 'always', 'never', 'auto' (default)
# We execute all notebooks, exclude the slow ones using 'exclude_patterns'
nbsphinx_execute = 'never'
...

For example, this is the Auto-Batching Log-Densities Example notebook that shows which cells were run by the original author before saving the notebook:

image

(The notebook cells weren't run by Sphinx/nbsphinx.)

If this works for all the notebooks, then, technically, you won't need to add a bunch of notebooks to exclude patterns because of their slowness to execute when building:

exclude_patterns = [
    # Slow notebook: long time to load tf.ds
    'notebooks/neural_network_with_tfds_data.ipynb',
    # Slow notebook
    'notebooks/Neural_Network_and_Data_Loading.ipynb',
    #'notebooks/score_matching.ipynb',
    'notebooks/maml.ipynb',
    # Fails with shape error in XL
    'notebooks/XLA_in_Python.ipynb',
    # Sometimes sphinx reads its own outputs as inputs!
    'build/html',
]

WDYT?

@8bitmp3 8bitmp3 closed this Jul 9, 2020
To enable TPU support in notebooks with the `pmap` transform in Sphinx, notebooks should not be executed during build before conversion. Here, you change 'always' to 'never', such that `nbsphinx_execute = 'never'`.
@8bitmp3 8bitmp3 reopened this Jul 10, 2020
@8bitmp3
Contributor Author

8bitmp3 commented Jul 10, 2020

@skye I've refactored the JAX Quickstart notebook in stages with minimum changes to the original layout of the site to help you with the review:

Commits 1 d262a3d and 2 74e1d61 address this:

Would it be possible to separate this out into two changes, one containing the changes to the notebook above the new pmap section, and one with everything else

For Jupyter notebooks with TPU mentioned in the JSONs - commits 3 4115c01 and 4 728dd96 address this:

What if we don't let the Sphinx plugin execute the notebooks when building the docs site?
Instead, we run the cells in Colab from top to bottom before saving the notebooks while keeping all the cell output.
Sphinx's Jupyter notebook plugin - nbsphinx - would render only the input and the output (which has been pre-run in Colab in advance).

I have built and run the site locally, and Sphinx's nbsphinx plugin for Jupyter notebooks goes straight to converting the .ipynb files to HTML without running them first. This means that the output saved in the original notebooks stays there (and Sphinx doesn't have to execute each cell before converting the notebooks).

And it looks like so far there's no need for any workarounds, such as:

So specifically, add quickstart.ipynb to exclude_patterns in conf.py, and then save a version of the notebook with outputs (looks like you already did this). I think that should be it...

I think if this works for you, the other notebooks that are currently in exclude_patterns can be added back.

LMKWYT

Thanks 🙌

@skye
Member

skye commented Jul 10, 2020

Thanks for separating this out, and for looking into solutions for the TPU execution!

What if we don't let the Sphinx plugin execute the notebooks when building the docs site?

We intentionally execute notebooks as part of the docs build process to make sure they're up-to-date and to block submitting PRs that break the notebooks. Ideally we'd be able to execute all notebooks, but as you've noticed, some of the notebooks are too slow to generate this way. So we definitely don't want to set nbsphinx_execute = 'never'. That's why I suggested adding the quickstart notebook to exclude_patterns.

MkDocs

I like these material themes, but let's try to limit the scope of this change :)

I'll take a look at the specific changes separately.

@8bitmp3
Contributor Author

8bitmp3 commented Aug 27, 2020

Hi @skye 👋 Hope y'all are safe and stuff!
Returning to this open PR 🙌

Thanks for separating this out, and for looking into solutions for the TPU execution!

What if we don't let the Sphinx plugin execute the notebooks when building the docs site?

We intentionally execute notebooks as part of the docs build process to make sure they're up-to-date and to block submitting PRs that break the notebooks. Ideally we'd be able to execute all notebooks, but as you've noticed, some of the notebooks are too slow to generate this way. So we definitely don't want to set nbsphinx_execute = 'never'. That's why I suggested adding the quickstart notebook to exclude_patterns.

Let's finish this task. How flexible is the team when it comes to configuring Sphinx in general? Or have y'all resolved the Cloud TPU incompatibility?

@skye
Member

skye commented Sep 1, 2020

Hi, sorry for the delay on this. I got swamped with some other work (and was on vacation for the last week), but agree we should get this done!

To make this simpler, can you remove the "2-Expand the intro, ..." commit from this PR and post it as a separate PR? That will make each PR much more self-contained and quicker to review, especially if others have opinions. And also change "Update JAX Quickstart with parallelization (pmap), other changes" to just exclude this notebook, as discussed above.

I'm actually out-of-office most of today too, but will try to review the pmap notebook changes this week.

@8bitmp3
Contributor Author

8bitmp3 commented May 5, 2021

We can revisit in a set of small PRs @skye 👍

@8bitmp3 8bitmp3 closed this May 5, 2021

4 participants