Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added argument check to all primitives. #3197

Merged
merged 1 commit into from
May 24, 2020

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented May 24, 2020

  • Added argument check to all primitives.

The issue that inspired this is that lax.tie_in is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that lax.tie_in((x, x), const)
is the same as const even though x is a tracer.

This error would be caught previously if core.skip_checks == False
because then bind checks its arguments. I have essentially added
an unconditional argument check to bind.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the impl rule would fire and numpy
would report the error somehow, perhaps.

  • Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.

 * Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as jax-ml#2948 but reverted awaiting the fixes
in some user code.
@gnecula gnecula merged commit f1ae216 into jax-ml:master May 24, 2020
@gnecula gnecula deleted the strict_checks2 branch May 24, 2020 16:12
mattjj added a commit that referenced this pull request May 31, 2020
The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
mattjj added a commit that referenced this pull request May 31, 2020
The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
@mattjj mattjj mentioned this pull request May 31, 2020
mattjj added a commit that referenced this pull request Jun 1, 2020
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jun 11, 2020
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as jax-ml#2948 but reverted awaiting the fixes
in some user code.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jun 11, 2020
revert find_top_trace change from jax-ml#3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants