cache tracing of (sub)calls when forming a jaxpr #9181
Closed
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
cc @hawkinsp
currently includes changes from #9188
The main motivations here are:
This PR on its own tackles (1) but not (2). It sets us up for (2) because subjaxprs are now identical Python objects when we trace to a jaxpr, and so we can have have lowerings which are themselves memoized on (sub)jaxpr object id or alternatively we can have an early pass to produce a
JaxprModule
with explicit outlined functions. The latter seems more robust and clearer.fixes #7155
TODO: