Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify custom_jvp_call_p, remove custom_jvp_call_jaxpr_p #9137

Merged
merged 1 commit into from
Aug 18, 2022

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Jan 8, 2022

Remove custom_jvp_call_jaxpr_p and its transformation rules. They were superfluous! Instead use the new mechanism for converting from jaxpr params to bind params (in #9136).

This PR currently includes the commit from #9136, but it should be considered as a "diffbase".

The simplification in JaxprTrace.process_custom_jvp_call was actually made possible by omnistaging #3370, though we didn't apply it until now.

In a follow-up PR we'll delete custom_vjp_call_jaxpr_p too (or die trying). I'd like to land this one fully first to make sure this approach works (and hence ensure doing the same to custom_vjp_call_p makes sense).

@mattjj mattjj requested a review from froystig January 8, 2022 04:24
@mattjj mattjj force-pushed the simplify-custom-jvp-call branch 7 times, most recently from e914e53 to de1fa44 Compare January 8, 2022 20:35
@mattjj mattjj changed the title Simplify custom jvp call simplify custom_jvp_call_p, remove custom_jvp_call_jaxpr_p Jan 8, 2022
@@ -2666,8 +2666,7 @@ def test_escaped_tracer_transform_name(self):
_ = self._saved_tracer+1

def test_escaped_tracer_shape_dtype(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
r"shape \(4, 3\) and dtype int32"):
with self.assertRaisesRegex(core.UnexpectedTracerError, r"int32\[4,3\]"):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this test because I improved the error message. I changed the error message in this PR because previously the error message formatting always attempted to print the shape and dtype associated with a Tracer, but that Tracer could be on a core.unit (which has no shape or dtype).

for t in out_tracers: t.recipe = eqn
return out_tracers
# We assume partial evaluation is only performed to build linear functions,
# and hence we don't need to keep the custom JVP rule around anymore.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notice that we already make this assumption, e.g. in post_process_custom_jvp_call. We just hadn't fully leveraged it because until #3370 we needed this approach for staging.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it turns out that we still need it for remat, until we upgrade to the new implementation in ad_checkpoint.py

@mattjj mattjj force-pushed the simplify-custom-jvp-call branch from 76d6592 to e76b622 Compare January 8, 2022 21:35
mattjj added a commit to mattjj/jax that referenced this pull request Jan 10, 2022
A few months ago we wrote a new version of remat to support custom
rematerialization policies. The new version is also built on a simpler
basic implementation. But instead of replacing the pre-existing remat
implementation, we kept both around because the new remat did not yet
support higher-order primitives like scan and cond. Tech debt!

I'd like to start moving on replacing the old remat implementation,
since it will enable some additional simplifications throughout the
system (like jax-ml#9137).

To that end, I'd like to first land a version of the new remat *which
only supports standard remat policies for scan and cond*. That is, we
won't yet have nonstandard remat policies which apply underneath scan
and cond.

This restricted version is useful because it's a drop-in replacement for
the old remat (which only supported a standard remat policy) and so
it'll let us delete the old implementation and remove the tech debt.

This PR updates all remat tests to check the new remat, and also adds
rules for scan and cond which work for default policies (and raise a
NotImplementedError for nonstandard policies).
mattjj added a commit to mattjj/jax that referenced this pull request Jan 10, 2022
A few months ago we wrote a new version of remat to support custom
rematerialization policies. The new version is also built on a simpler
basic implementation. But instead of replacing the pre-existing remat
implementation, we kept both around because the new remat did not yet
support higher-order primitives like scan and cond. Tech debt!

I'd like to start moving on replacing the old remat implementation,
since it will enable some additional simplifications throughout the
system (like jax-ml#9137).

To that end, I'd like to first land a version of the new remat *which
only supports standard remat policies for scan and cond*. That is, we
won't yet have nonstandard remat policies which apply underneath scan
and cond.

This restricted version is useful because it's a drop-in replacement for
the old remat (which only supported a standard remat policy) and so
it'll let us delete the old implementation and remove the tech debt.

This PR updates all remat tests to check the new remat, and also adds
rules for scan and cond which work for default policies (and raise a
NotImplementedError for nonstandard policies).
mattjj added a commit to mattjj/jax that referenced this pull request Jan 10, 2022
A few months ago we wrote a new version of remat to support custom
rematerialization policies. The new version is also built on a simpler
basic implementation. But instead of replacing the pre-existing remat
implementation, we kept both around because the new remat did not yet
support higher-order primitives like scan and cond. Tech debt!

I'd like to start moving on replacing the old remat implementation,
since it will enable some additional simplifications throughout the
system (like jax-ml#9137).

To that end, I'd like to first land a version of the new remat *which
only supports standard remat policies for scan and cond*. That is, we
won't yet have nonstandard remat policies which apply underneath scan
and cond.

This restricted version is useful because it's a drop-in replacement for
the old remat (which only supported a standard remat policy) and so
it'll let us delete the old implementation and remove the tech debt.

This PR updates all remat tests to check the new remat, and also adds
rules for scan and cond which work for default policies (and raise a
NotImplementedError for nonstandard policies).
@@ -298,119 +311,31 @@ def _apply_todos(todos, outs):

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
# TODO(mattjj): could do more checking here...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why prevent core.check_call? If it's necessary, then can we still carry out some of the checks from core.check_call, if not all?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_call doesn't quite work for custom_jvp_call; for example, there's no call_jaxpr parameter here. Moreover there are other things we might want to check here, like num_consts and maybe some things about jvp_jaxpr_thunk too.

So while I suspect we can factor out some shared helper functions, in general custom_jvp_call will need a custom typecheck rule.

Copy link
Member

@froystig froystig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I believe this tracks/reflects what we've been doing on the custom transpose front too, so this serves to confirm that as well in a sense.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Aug 18, 2022
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in jax-ml#9136).

This change languished until we could land jax-ml#11830 / jax-ml#11950 and friends. But now
we can!
@mattjj mattjj force-pushed the simplify-custom-jvp-call branch from ad71490 to 887b7ce Compare August 18, 2022 04:12
@copybara-service copybara-service bot merged commit af7d1c4 into jax-ml:main Aug 18, 2022
@mattjj mattjj deleted the simplify-custom-jvp-call branch August 18, 2022 05:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants