Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

autodidax: jit, DeviceArrays, multi-output, pytrees #5856

Merged
merged 1 commit into from
Mar 9, 2021

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Feb 26, 2021

The changes here, in no particular order, are roughly:

  • Added transpose and broadcast primitives, b/c they're needed by vmap. (We could instead switch to a simplified version of vmap that doesn't need these, but this way it's a bit more faithful to the real jax.vmap.)
  • Updated jaxprs, primitives, and transformations to be multi-output, since we need that for jvp (especially jvp-of-jit).
  • For the preceding bullet, added pytrees!
  • Added jit, with caching and DeviceArrays, as well as jvp and vmap rules for it. That entailed adding an eval_jaxpr, and not much else!
  • Fixed a subtle bug where we drop references to Tracers and their memory addresses get reused. It's funny because I had the same bug in omnistaging #3370.
  • Added some text about "initial-style" vs "final-style".
$ cloc docs/autodidax.py
       1 text file.
       1 unique file.
       0 files ignored.

github.com/AlDanial/cloc v 1.86  T=0.01 s (71.0 files/s, 117368.4 lines/s)
-------------------------------------------------------------------------------
Language                     files          blank        comment           code
-------------------------------------------------------------------------------
Python                           1            295            460            898
-------------------------------------------------------------------------------

I think we'll be able to get grad in here at under 1000 SLoC, though probably not final-style calls (i.e. custom_jvp) or pmap.

@google-cla google-cla bot added the cla: yes label Feb 26, 2021
@mattjj mattjj force-pushed the autodidax branch 4 times, most recently from d4d2765 to a5a4b68 Compare February 27, 2021 06:27
@mattjj mattjj force-pushed the autodidax branch 4 times, most recently from 056bca0 to 016400b Compare March 6, 2021 03:37
@mattjj mattjj changed the title start autodidax jit autodidax: jit, DeviceArrays, multi-output, pytrees Mar 6, 2021
@mattjj mattjj requested a review from froystig March 6, 2021 03:49
@mattjj mattjj marked this pull request as ready for review March 6, 2021 03:49
Copy link
Member

@froystig froystig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice.

@mattjj mattjj added the pull ready Ready for copybara import and testing label Mar 9, 2021
@copybara-service copybara-service bot merged commit 0b88b0e into master Mar 9, 2021
@mattjj mattjj deleted the autodidax branch March 9, 2021 04:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants