simplify custom_jvp_call_p, remove custom_jvp_call_jaxpr_p #9137
@@ -2666,8 +2666,7 @@ def test_escaped_tracer_transform_name(self):
       _ = self._saved_tracer+1

   def test_escaped_tracer_shape_dtype(self):
-    with self.assertRaisesRegex(core.UnexpectedTracerError,
-                                r"shape \(4, 3\) and dtype int32"):
+    with self.assertRaisesRegex(core.UnexpectedTracerError, r"int32\[4,3\]"):
I changed this test because I improved the error message. I changed the error message in this PR because previously the error message formatting always attempted to print the shape and dtype associated with a Tracer, but that Tracer could be on a core.unit (which has no shape or dtype).
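For readers unfamiliar with the error under test, here is a minimal sketch of how a tracer can escape and trigger it. The names `saved` and `leaky` are hypothetical, and the exact wording of the message depends on the JAX version, so the sketch only checks that the error is raised.

```python
import jax
import jax.numpy as jnp

saved = []  # hypothetical container used to leak a tracer out of the trace

@jax.jit
def leaky(x):
    saved.append(x)  # side effect: the tracer outlives the traced function
    return x + 1

leaky(jnp.ones((4, 3), dtype=jnp.int32))

try:
    _ = saved[0] + 1  # using the escaped tracer after tracing has finished
    raised = False
except jax.errors.UnexpectedTracerError as err:
    raised = True
    message = str(err)  # describes the leaked value; format varies by version
```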
    for t in out_tracers: t.recipe = eqn
    return out_tracers
    # We assume partial evaluation is only performed to build linear functions,
    # and hence we don't need to keep the custom JVP rule around anymore.
Notice that we already make this assumption, e.g. in post_process_custom_jvp_call. We just hadn't fully leveraged it because until #3370 we needed this approach for staging.
Actually, it turns out that we still need it for remat, until we upgrade to the new implementation in ad_checkpoint.py.
A few months ago we wrote a new version of remat to support custom rematerialization policies. The new version is also built on a simpler basic implementation. But instead of replacing the pre-existing remat implementation, we kept both around because the new remat did not yet support higher-order primitives like scan and cond. Tech debt! I'd like to start moving on replacing the old remat implementation, since it will enable some additional simplifications throughout the system (like jax-ml#9137). To that end, I'd like to first land a version of the new remat *which only supports standard remat policies for scan and cond*. That is, we won't yet have nonstandard remat policies which apply underneath scan and cond. This restricted version is useful because it's a drop-in replacement for the old remat (which only supported a standard remat policy) and so it'll let us delete the old implementation and remove the tech debt. This PR updates all remat tests to check the new remat, and also adds rules for scan and cond which work for default policies (and raise a NotImplementedError for nonstandard policies).
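As a rough illustration of the drop-in behaviour described above, the sketch below applies rematerialization with the default policy to a function containing `scan` and compares gradients with and without it. It assumes a JAX version in which `jax.checkpoint` refers to the newer implementation; the functions `body` and `f` are made up for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    # One scan step: the carry is updated nonlinearly.
    carry = jnp.sin(carry * x)
    return carry, carry

def f(c0, xs):
    final, _ = lax.scan(body, c0, xs)
    return final

xs = jnp.arange(1.0, 4.0)
c0 = jnp.float32(0.5)

# Gradient without rematerialization.
g_plain = jax.grad(f)(c0, xs)
# Gradient with remat (default policy) wrapped around the scan.
g_remat = jax.grad(jax.checkpoint(f))(c0, xs)
```

Rematerialization changes what is saved for the backward pass, not the value of the result, so the two gradients should agree up to floating-point tolerance.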
@@ -298,119 +311,31 @@ def _apply_todos(todos, outs):

custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
  # TODO(mattjj): could do more checking here...
Why prevent core.check_call? If it's necessary, then can we still carry out some of the checks from core.check_call, if not all?
check_call doesn't quite work for custom_jvp_call; for example, there's no call_jaxpr parameter here. Moreover, there are other things we might want to check here, like num_consts and maybe some things about jvp_jaxpr_thunk too.
So while I suspect we can factor out some shared helper functions, in general custom_jvp_call will need a custom typecheck rule.
Looks good. I believe this tracks/reflects what we've been doing on the custom transpose front too, so this serves to confirm that as well in a sense.
They were superfluous! Instead use the "new" mechanism for converting from jaxpr params to bind params (in jax-ml#9136). This change languished until we could land jax-ml#11830 / jax-ml#11950 and friends. But now we can!
Remove custom_jvp_call_jaxpr_p and its transformation rules. They were superfluous! Instead use the new mechanism for converting from jaxpr params to bind params (in #9136).
This PR currently includes the commit from #9136, but that commit should be considered a "diffbase". The simplification in JaxprTrace.process_custom_jvp_call was actually made possible by omnistaging (#3370), though we didn't apply it until now.
In a follow-up PR we'll delete custom_vjp_call_jaxpr_p too (or die trying). I'd like to land this one fully first to make sure this approach works (and hence ensure doing the same to custom_vjp_call_p makes sense).
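For context, the primitive discussed here backs the public `jax.custom_jvp` API. The following is a minimal sketch of that user-facing behaviour, not of the internals touched by this PR; the function `f` and its deliberately nonstandard rule are made up so that it is visible when the custom rule is used, both eagerly and when staged out under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Deliberately not the true derivative (note the factor 2.0), so we can
    # tell that this rule, rather than autodiff of sin, was applied.
    return f(x), 2.0 * jnp.cos(x) * t

grad_eager = jax.grad(f)(0.0)            # custom rule applied directly
grad_staged = jax.jit(jax.grad(f))(0.0)  # custom rule applied under jit
```

With the rule above, both values should equal 2 * cos(0), whereas plain differentiation of `jnp.sin` would give cos(0).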