Change in behavior of boolean masks inside jitted functions in jax 0.2.0 #4471

Closed
janosg opened this issue Oct 7, 2020 · 3 comments
Labels
question Questions for the JAX team

Comments

janosg commented Oct 7, 2020

Problem

Indexing with a boolean mask that is generated inside a jitted function (but depends only on static arguments) stopped working in version 0.2.0. It worked in 0.1.77 and earlier. If the mask is passed in as a static argument, it still works.

Error Message

Traceback (most recent call last):
  File "mixed_documents/jax_error.py", line 23, in <module>
    wont_work(x, y)
  File "mixed_documents/jax_error.py", line 11, in create_mask_inside_jitted_function
    return x[mask].sum()
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3783, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3843, in _split_index_for_jit
    idx = _expand_bool_indices(idx)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 4102, in _expand_bool_indices
    raise IndexError("Array boolean indices must be concrete.")
jax.traceback_util.FilteredStackTrace: IndexError: Array boolean indices must be concrete.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "mixed_documents/jax_error.py", line 23, in <module>
    wont_work(x, y)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/traceback_util.py", line 137, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/api.py", line 209, in f_jitted
    out = xla.xla_call(
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/core.py", line 1144, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/core.py", line 1135, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/core.py", line 1147, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/core.py", line 577, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/interpreters/xla.py", line 529, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/linear_util.py", line 234, in memoized_fun
    ans = call(fun, *args)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/interpreters/xla.py", line 595, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1023, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1004, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "mixed_documents/jax_error.py", line 11, in create_mask_inside_jitted_function
    return x[mask].sum()
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/core.py", line 505, in __getitem__
    def __getitem__(self, idx): return self.aval._getitem(self, idx)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3783, in _rewriting_take
    treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3843, in _split_index_for_jit
    idx = _expand_bool_indices(idx)
  File "/home/janos/anaconda3/envs/minimal_jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 4102, in _expand_bool_indices
    raise IndexError("Array boolean indices must be concrete.")
IndexError: Array boolean indices must be concrete.

Minimal Example

import jax.numpy as jnp
import jax

def pass_in_mask(x, y, mask):
    return x[mask].sum()

def create_mask_inside_jitted_function(x, y):
    mask = jnp.isfinite(y)
    return x[mask].sum()

x = jnp.array([1., 2, 3])
y = jnp.array([0, jnp.inf, jnp.nan])
mask = jnp.isfinite(y)

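will_work = jax.jit(pass_in_mask, static_argnums=(1, 2))  # mask is passed in as a static (concrete) argument
wont_work = jax.jit(create_mask_inside_jitted_function, static_argnums=1)  # mask is built with jnp inside jit

will_work(x, y, mask)  # works in jax 0.2.0
wont_work(x, y)  # raises IndexError: Array boolean indices must be concrete.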
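The same failure can be reproduced on a recent JAX release with the minimal sketch below. It is an adaptation, not the reporter's exact code: y is closed over as a constant instead of being passed through static_argnums, because recent releases reject non-hashable static arguments such as arrays, and the except clause assumes the raised error is a subclass of IndexError.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Assumption: y is a closed-over constant, so its value is known before tracing.
y = np.array([0.0, np.inf, np.nan])


@jax.jit
def create_mask_inside_jit(x):
    # The jnp call is staged out, so `mask` is an abstract tracer even though
    # it depends only on the constant y.
    mask = jnp.isfinite(y)
    # Boolean indexing needs a concrete mask to know the output shape.
    return x[mask].sum()


x = jnp.array([1.0, 2.0, 3.0])
try:
    create_mask_inside_jit(x)
    raised = False
except IndexError as err:  # assumption: recent JAX raises an IndexError subclass here
    raised = True
    print(type(err).__name__, err)
```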

Conda Environment

name: minimal_jax
channels:
  - conda-forge
dependencies:
  - python=3.8
  # - jax=0.1.77
  - jax=0.2.0
  - jaxlib=0.1.55
hawkinsp (Collaborator) commented Oct 7, 2020

This is working as intended: jax 0.2.0 includes a major change ("omnistaging") to how JAX stages out computations. There's a longer document coming, but #3370 has a description of the change.
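Under omnistaging, every jax.numpy call inside a jitted function is staged out, even when its inputs are constants, whereas classic NumPy still runs eagerly in Python during tracing. A minimal sketch to observe this on a recent JAX release; the helper is_concrete and the constant y are hypothetical, and the check relies on np.asarray raising jax.errors.TracerArrayConversionError when given a tracer.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical constant input: its value is fully known before tracing starts.
y = np.array([0.0, np.inf, np.nan])
seen = {}


def is_concrete(value):
    """Return True if `value` converts to a NumPy array, i.e. it is not a tracer."""
    try:
        np.asarray(value)
        return True
    except jax.errors.TracerArrayConversionError:
        return False


@jax.jit
def inspect_masks(x):
    # This Python body runs at trace time, so `seen` records what tracing observed.
    seen["np_mask_concrete"] = is_concrete(np.isfinite(y))    # classic NumPy: evaluated eagerly
    seen["jnp_mask_concrete"] = is_concrete(jnp.isfinite(y))  # jax.numpy: staged out as a tracer
    return x * 2


inspect_masks(jnp.array([1.0, 2.0, 3.0]))
# Outside jit, the same jnp call returns an ordinary, concrete array.
outside_concrete = is_concrete(jnp.isfinite(y))
print(seen, outside_concrete)
```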

Briefly, there are two basic things you could do here:

  • if you want the mask to be resolved at compile time, use a static argument and use classic NumPy for the computation in question. This avoids staging the masking computation out to XLA.
  • if you want the mask to be resolved at run time, use the 3-argument form of jnp.where, which avoids the need for concrete boolean indices.
jnp.where(jnp.isfinite(y), y, 0).sum()
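Both options can be sketched for the reporter's function as follows. This assumes a recent JAX release; the function names are hypothetical, y is passed as a tuple in the first option because static arguments must be hashable, and the second option selects x (not y) so that it matches the reporter's x[mask].sum().

```python
import numpy as np
import jax
import jax.numpy as jnp


# Option 1: mask resolved at compile time. y arrives as a hashable static
# argument and the mask is built with classic NumPy, so it stays concrete.
def sum_finite_static(x, y_static):
    mask = np.isfinite(np.asarray(y_static))  # concrete NumPy bool array
    return x[mask].sum()


sum_finite_static_jit = jax.jit(sum_finite_static, static_argnums=1)


# Option 2: mask resolved at run time. The 3-argument jnp.where keeps the
# output shape fixed, so the mask is allowed to be a tracer.
@jax.jit
def sum_finite_where(x, y):
    return jnp.where(jnp.isfinite(y), x, 0).sum()


x = jnp.array([1.0, 2.0, 3.0])
y_tuple = (0.0, float("inf"), float("nan"))
static_result = sum_finite_static_jit(x, y_tuple)
where_result = sum_finite_where(x, jnp.array(y_tuple))
print(static_result, where_result)
```

The trade-off: the first option compiles a separate program for each distinct static y, while the second compiles once and works for any y with the same shape and dtype.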

Does that resolve the issue?

hawkinsp added the question label (Questions for the JAX team) on Oct 7, 2020
hawkinsp (Collaborator) commented Oct 7, 2020

janosg (Author) commented Oct 7, 2020

Thanks a lot for the clarification and tips! Yes, this resolves the issue. Especially the design notes are very helpful.

janosg closed this as completed on Oct 7, 2020