Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache tracing of (sub)calls when forming a jaxpr #9181

Closed
wants to merge 3 commits into from

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Jan 12, 2022

cc @hawkinsp

currently includes changes from #9188

The main motivations here are:

  1. reduce trace times (by not retracing functions to a jaxpr when we can just look up the result in a cache), and
  2. reduce compilation times by lowering outlined (as opposed to inlined) functions to our compiler targets.

This PR on its own tackles (1) but not (2). It sets us up for (2) because subjaxprs are now identical Python objects when we trace to a jaxpr, and so we can have have lowerings which are themselves memoized on (sub)jaxpr object id or alternatively we can have an early pass to produce a JaxprModule with explicit outlined functions. The latter seems more robust and clearer.

fixes #7155

TODO:

  • traceback surgery for source info (that's the last failing test)

@mattjj mattjj force-pushed the subcall-trace-caching branch 6 times, most recently from 517466b to 8197a1b Compare January 13, 2022 06:43
@mattjj mattjj requested a review from hawkinsp January 13, 2022 06:45
@mattjj mattjj force-pushed the subcall-trace-caching branch 4 times, most recently from fa2d2fe to b3cfa06 Compare January 13, 2022 17:19
@mattjj
Copy link
Collaborator Author

mattjj commented Jan 26, 2022

@patrick-kidger check out this PR if you haven't already! Steps towards a non-inlined and faster-compiling world.

@mattjj
Copy link
Collaborator Author

mattjj commented Jul 20, 2022

I think this is subsumed by #10775 (for the subjaxpr stuff) and #11298 (for the axis env stuff).

@mattjj mattjj closed this Jul 20, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

inner jit functions are re-traced (and re-compiled)
1 participant