Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

omnistaging on by default #4038

Merged
merged 1 commit into from
Sep 15, 2020
Merged

omnistaging on by default #4038

merged 1 commit into from
Sep 15, 2020

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Aug 12, 2020

fixes #4114, fixes #3397, fixes #3108,

@google-cla google-cla bot added the cla: yes label Aug 12, 2020
@mattjj mattjj force-pushed the omnistaging-on-by-default branch 4 times, most recently from e7b74ca to 241267c Compare August 12, 2020 21:56
@tomhennigan
Copy link
Collaborator

:shipit:

@mattjj mattjj force-pushed the omnistaging-on-by-default branch 9 times, most recently from 1bce9c3 to 12916cf Compare August 14, 2020 20:46
@mattjj mattjj force-pushed the omnistaging-on-by-default branch from c5f2528 to 9cb19d6 Compare September 15, 2020 14:27
@mattjj mattjj marked this pull request as ready for review September 15, 2020 14:51
@mattjj mattjj merged commit 2678a46 into master Sep 15, 2020
@mattjj mattjj deleted the omnistaging-on-by-default branch September 15, 2020 15:27
mattjj added a commit that referenced this pull request Sep 18, 2020
Previously, given this function:

```python
@jax.jit
def f(x,y):
  if x > y:
    return x
  else:
    return y
```

we'd get an error message like this (after #4038, improved to help with
omnistaging debugging):

```
...

While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:

  operation c:bool[] = gt a:int32[] b:int32[]
    from line tim.py:5 (f)

...
```

But this message is buggy! In this case, the value is a tracer because
it has a data dependence on arguments to a jitted function.

After this change, we instead produce this error message:

```
...

While tracing the function f at tim.py:3, this concrete value was not available in Python because it depends on the value of the arguments to f at tim.py:3 at positions [0, 1], and the computation of these values is being staged out.

...
```

I'm eager to iterate with further improvements, but for now I want to
fix this buggy message.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
hawkinsp added a commit to hawkinsp/jax that referenced this pull request Mar 4, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Mar 17, 2021
Updated version of jax-ml#4536.

This is removing the device constant part of jax-ml#1668. We can do this because after jax-ml#3370 and jax-ml#4038 omnistaging removes the need for lazy device constants in a jitted context. (They could still in principle be useful in an op-by-op context, but the power:weight isn't worthwhile anymore.)

After this change, the only parts of the lazy sublanguage that remain are those to do with broadcasts and transposes. We may or may not kill those in a follow-up (it hinges on whether any benefit to op-by-op execution is worth the extra complexity).

This change regresses non-omnistaging users. As one particular example, test_eval_shape_big_random_array no longer passes with omnistaging disabled.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
3 participants