-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
revisions to #3197 #3264
revisions to #3197 #3264
Conversation
The previous version was written and tested for performance; the revised version caused at least a 25% slowdown in the dispatch time of `lax.add(1, 2)` (and so likely a much bigger slowdown for the find_top_trace timing alone). Instead, we can just change the error message in xla.abstractify, since invalid types lead to abstractification errors when we apply primitive impls.
tests/api_test.py
Outdated
def test_primitive_jaxtype_error(self): | ||
with self.assertRaisesRegex( | ||
TypeError, "Argument .* of type .* is not a valid JAX type"): | ||
lax.add(1, 'hi') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am surprised that the error in xla.abstractify is triggered here. I can't quite tell how. Probably by way of core._valid_jaxtype
which is only called by core._check_arg
but then I do not see how lax.add
will call _check_arg. Is lax.add
a compiled primitive?
I see now, all arithmetic operators are compiled. This solution then does not work for the non-compiled lax primitives.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right yeah. In op-by-op mode the impls should check the types of their arguments, and any impl based on xla.apply_primitive
needs to go through xla and hence xla.abstractify
. In jaxpr tracing, the type checking should happen in the abstract eval rules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you intentionally drop
assert skip_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
from bind? Also for speed?
No, that was not intentional. This PR didn't edit bind, but I see now that #3197 did, and I failed to restore that bit. (Things gated by skip_checks can be expensive!) Thanks for the catch! |
revert find_top_trace change from jax-ml#3197 The previous version was written and tested for performance; the revised version caused at least a 25% slowdown in the dispatch time of `lax.add(1, 2)` (and so likely a much bigger slowdown for the find_top_trace timing alone). Instead, we can just change the error message in xla.abstractify, since invalid types lead to abstractification errors when we apply primitive impls.
The previous version of
core.find_top_trace
was written and tested for performance; the revised version of #3197 caused at least a 25% slowdown in the dispatch time oflax.add(1, 2)
(and so likely a much bigger slowdown for the find_top_trace timing alone).Moreover, the check isn't so natural in core.py. Instead, we can just change the error message in xla.abstractify, since invalid types lead to abstractification errors when we apply primitive impls.
Revises #3197