Using a boolean mask that is generated inside a jitted function (but depends only on static arguments) stopped working in version 0.2.0. It worked in versions 0.1.77 and earlier. If the mask is passed in as a static argument, it still works.
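A minimal sketch of the pattern described, assuming a mask built with `jnp` from two static arguments (`masked_sum`, `n` and `k` are hypothetical names chosen for illustration). Even though `n` and `k` are static, `jnp.arange(n)` is traced inside the jitted function on JAX 0.2.0 and later, so the boolean index is not concrete and indexing fails.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1, 2))
def masked_sum(x, n, k):
    # Hypothetical example: n and k are static Python ints, but jnp.arange
    # is still staged out under jit, so `mask` is an abstract tracer.
    mask = jnp.arange(n) < k
    # Boolean indexing needs a concrete mask to know the output shape.
    return x[mask].sum()


x = jnp.arange(5.0)
try:
    masked_sum(x, 5, 3)
    failed = False
except Exception as err:  # a non-concrete boolean index error on newer JAX
    failed = True
    print(type(err).__name__)
```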
This is working as intended: jax 0.2.0 includes a major change ("omnistaging") to how JAX stages out computations. There's a longer document coming, but #3370 has a description of the change.
Briefly, there are two basic things you could do here:
if you want the mask to be resolved at compile time, use a static argument and use classic NumPy for the computation in question. This avoids staging the masking computation out to XLA.
if you want the mask to be resolved at run time, use the 3-argument form of jnp.where, which avoids the need for boolean indices to be concrete.
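Both approaches can be sketched as follows, assuming the same hypothetical masked-sum computation (the function names are illustrative). The first builds the mask with classic NumPy from static arguments, so it is a concrete array at trace time. The second keeps the mask traced and uses `jnp.where(mask, x, 0.0)`, whose output has the same shape as `x`, so no concrete boolean index is required.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


# Option 1: resolve the mask at compile time. n and k are static, and the
# mask is built with classic NumPy, so it is never staged out to XLA.
@partial(jax.jit, static_argnums=(1, 2))
def masked_sum_static(x, n, k):
    mask = np.arange(n) < k   # concrete NumPy bool array during tracing
    return x[mask].sum()      # concrete boolean index -> static output shape


# Option 2: resolve the mask at run time. The 3-argument jnp.where keeps the
# output shape equal to x.shape, so the mask (and k) may be tracers.
@jax.jit
def masked_sum_where(x, k):
    mask = jnp.arange(x.shape[0]) < k  # traced mask; k is not static here
    return jnp.where(mask, x, 0.0).sum()


x = jnp.arange(5.0)
static_result = masked_sum_static(x, 5, 3)  # sums x[0] + x[1] + x[2]
runtime_result = masked_sum_where(x, 3)     # same sum, mask built at run time
```

The trade-off: option 1 recompiles for every new `(n, k)` pair because they are static, while option 2 compiles once per input shape and lets `k` vary freely at run time.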