jnp.broadcast_to is incompatible with masking #2769

Closed
yingted opened this issue Apr 20, 2020 · 5 comments

@yingted
Contributor

yingted commented Apr 20, 2020

I want to call some vectorized functions from a masked function.

These functions commute with masking: f(x)[mask] == f(x[mask]), where mask = np.s_[:d0, :d1, ..., :dn].
But I don't think JAX needs to know this.
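
As a concrete illustration of the commuting property, here is a minimal sketch in plain NumPy (np.vectorize accepts the same gufunc-style signature string as jnp.vectorize). The array sizes and the names in_mask and out_mask are assumptions chosen for this example. Note that the output mask has one fewer axis than the input mask, because the (c,c)->(c) signature drops one core dimension.

```python
import numpy as np

# Vectorized diagonal over a leading batch axis, mirroring the issue's diag3d.
diag3d = np.vectorize(np.diag, signature='(c,c)->(c)')

# Padded input of shape (3, 3, 3); the "logical" data is the leading (2, 2, 2) block.
x = np.arange(27).reshape((3, 3, 3))
in_mask = np.s_[:2, :2, :2]   # slices for the (t, c, c) input
out_mask = np.s_[:2, :2]      # slices for the (t, c) output

lhs = diag3d(x)[out_mask]     # apply to the padded array, then slice
rhs = diag3d(x[in_mask])      # slice first, then apply

np.testing.assert_array_equal(lhs, rhs)
```

Both orders give the same result here, which is the sense in which such a function "commutes" with slicing off the padding.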

However, jnp.vectorize crashes when it broadcasts a masked array:

import jax  # needed for jax.mask below
import jax.numpy as jnp
import numpy as np

# Vectorized function, jnp.diag for simplicity:
diag3d = jnp.vectorize(jnp.diag, signature='(c,c)->(c)')
x = jnp.arange(8).reshape((2, 2, 2))
y = diag3d(x)
# Try calling it with masked inputs:
diag3d_masked = jax.mask(diag3d, in_shapes=['(t, c, c)'], out_shape='(t, c)')
np.testing.assert_allclose(diag3d_masked([x], dict(t=2, c=2)), y)

I get this error:

/usr/local/lib/python3.6/dist-packages/jax/numpy/vectorize.py in wrapped(*args)
    237       vec_shape = full_shape[-arg.ndim:] if arg.ndim else ()
    238 
--> 239       vec_arg = np.broadcast_to(arg, vec_shape)
    240       vec_args.append(vec_arg)
    241 

/usr/local/lib/python3.6/dist-packages/jax/numpy/lax_numpy.py in broadcast_to(arr, shape)
   1061   """Like Numpy's broadcast_to but doesn't necessarily return views."""
   1062   arr = arr if isinstance(arr, ndarray) else array(arr)
-> 1063   shape = tuple(map(int, shape))  # check that shape is concrete
   1064   arr_shape = _shape(arr)
   1065   if arr_shape == shape:

/usr/local/lib/python3.6/dist-packages/jax/interpreters/masking.py in __int__(self)
    209 
    210   def __int__(self):
--> 211     assert self.is_constant
    212 
    213     return int(next(iter(self.values())))

AssertionError: 
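
Reading the traceback, the failure appears to come from broadcast_to coercing every shape entry with int(), which a symbolic masked dimension refuses unless it is constant. The following is a toy sketch of that mechanism in plain Python. SymbolicDim and concrete_shape are hypothetical stand-ins written for illustration, not JAX's actual classes or functions.

```python
class SymbolicDim(dict):
    """Hypothetical shape polynomial: maps a variable name to its coefficient.

    The empty-string key holds the constant term.
    """

    @property
    def is_constant(self):
        # Constant only if the sole term is the constant term.
        return len(self) == 1 and "" in self

    def __int__(self):
        # Same pattern as the __int__ shown in the traceback.
        assert self.is_constant
        return int(next(iter(self.values())))


def concrete_shape(shape):
    # Same coercion as the "check that shape is concrete" line in broadcast_to.
    return tuple(map(int, shape))


# A constant dimension converts without trouble.
concrete = concrete_shape((SymbolicDim({"": 2}), 3))

# A dimension that depends on the symbolic size t cannot be converted.
try:
    concrete_shape((SymbolicDim({"t": 1}), 2))
    failed = False
except AssertionError:
    failed = True
```

Under these assumptions, any code path that insists on concrete integer shapes would fail the same way once a dimension depends on a masked size such as t.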
@shoyer
Collaborator

shoyer commented Apr 20, 2020

It looks like this may be more directly an incompatibility between broadcast_to and masking.

@jekbradbury jekbradbury changed the title from "jax.mask(jax.vectorize(func)) doesn't work" to "jnp.broadcast_to is incompatible with masking" on Apr 21, 2020
@jekbradbury
Contributor

I think that particular issue was one of the missing masking features added in the rolled-back #2017 (cc @j-towns), though there may be further issues after broadcast_to is fixed.

@j-towns
Contributor

j-towns commented Apr 22, 2020

(ping @juliuskunze, who wrote #2017) I think that's correct. As it happens, @mattjj messaged me yesterday about fixing up and re-merging that PR, so we may be able to fix this issue quite quickly.

@juliuskunze
Contributor

juliuskunze commented Apr 28, 2020

Once #2800 is merged, I can prepare another PR that should fix this issue as well as masking for most other missing primitives. The code is currently parked in https://github.com/JuliusKunze/jax/tree/masking-pcnn and needs some heavy updating to be compatible with the current JAX version.

@hawkinsp
Collaborator

jax.mask is long gone. Closing as stale.
