Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve a ConcretizationTypeError message from dependence on jitted function arguments #4342

Merged
merged 7 commits into from
Sep 26, 2020

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Sep 18, 2020

Previously, given this function:

import jax

@jax.jit
def f(x,y):
  if x > y:
    return x
  else:
    return y

we'd get an error message like this (after #4038, improved to help with omnistaging debugging):

...

While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:

  operation c:bool[] = gt a:int32[] b:int32[]
    from line tim.py:5 (f)

...

But this message is buggy! In this case, the value is a tracer because it has a data dependence on arguments to a jitted function.

After this change, we instead produce this error message:

...

While tracing the function f at tim.py:3, this concrete value was not available
in Python because it depends on the value of the arguments to f at tim.py:3
at positions [0, 1], and the computation of these values is being staged out.

...

I'm eager to iterate with further improvements, but for now I want to fix this buggy message.

Previously, given this function:

```python
@jax.jit
def f(x,y):
  if x > y:
    return x
  else:
    return y
```

we'd get an error message like this (after #4038, improved to help with
omnistaging debugging):

```
...

While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:

  operation c:bool[] = gt a:int32[] b:int32[]
    from line tim.py:5 (f)

...
```

But this message is buggy! In this case, the value is a tracer because
it has a data dependence on arguments to a jitted function.

After this change, we instead produce this error message:

```
...

While tracing the function f at tim.py:3, this concrete value was not available in Python because it depends on the value of the arguments to f at tim.py:3 at positions [0, 1], and the computation of these values is being staged out.

...
```

I'm eager to iterate with further improvements, but for now I want to
fix this buggy message.
@google-cla google-cla bot added the cla: yes label Sep 18, 2020
@mattjj mattjj changed the title Improve a tracer error message improve a ConcretizationTypeError message from dependence on jitted function arguments Sep 18, 2020
@mattjj mattjj requested a review from jakevdp September 25, 2020 22:13
@mattjj mattjj added the pull ready Ready for copybara import and testing label Sep 25, 2020
@copybara-service copybara-service bot merged commit 5b3cbc5 into master Sep 26, 2020
@jakevdp jakevdp deleted the improve-tracer-error branch October 6, 2021 19:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants