Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

revisions to #3197 #3264

Merged
merged 5 commits into from
Jun 1, 2020
Merged

revisions to #3197 #3264

merged 5 commits into from
Jun 1, 2020

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented May 31, 2020

The previous version of core.find_top_trace was written and tested for performance; the revised version of #3197 caused at least a 25% slowdown in the dispatch time of lax.add(1, 2) (and so likely a much bigger slowdown for the find_top_trace timing alone).

Moreover, the check isn't so natural in core.py. Instead, we can just change the error message in xla.abstractify, since invalid types lead to abstractification errors when we apply primitive impls.

Revises #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
@mattjj mattjj requested a review from gnecula May 31, 2020 17:37
def test_primitive_jaxtype_error(self):
with self.assertRaisesRegex(
TypeError, "Argument .* of type .* is not a valid JAX type"):
lax.add(1, 'hi')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surprised that the error in xla.abstractify is triggered here. I can't quite tell how. Probably by way of core._valid_jaxtype which is only called by core._check_arg but then I do not see how lax.add will call _check_arg. Is lax.add a compiled primitive?

I see now, all arithmetic operators are compiled. This solution then does not work for the non-compiled lax primitives.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right yeah. In op-by-op mode the impls should check the types of their arguments, and any impl based on xla.apply_primitive needs to go through xla and hence xla.abstractify. In jaxpr tracing, the type checking should happen in the abstract eval rules.

Copy link
Collaborator

@gnecula gnecula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you intentionally drop

assert skip_checks or all(isinstance(arg, Tracer)
                              or valid_jaxtype(arg) for arg in args), args

from bind? Also for speed?

@mattjj
Copy link
Collaborator Author

mattjj commented May 31, 2020

Did you intentionally drop [...] from bind

No, that was not intentional. This PR didn't edit bind, but I see now that #3197 did, and I failed to restore that bit. (Things gated by skip_checks can be expensive!)

Thanks for the catch!

@mattjj mattjj merged commit 49a441f into master Jun 1, 2020
@mattjj mattjj deleted the revise-3197 branch June 1, 2020 20:24
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jun 11, 2020
revert find_top_trace change from jax-ml#3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants