
New to JAX, grad(f) throwing AssertionError with no message #2519

Closed
DanPuzzuoli opened this issue Mar 26, 2020 · 6 comments
Labels
bug Something isn't working

Comments

@DanPuzzuoli
Contributor

Hi,

I hope this is the correct place to post this: I'm new to JAX and am getting an error when trying to use grad. It seems like I am doing things correctly: the function g I define evaluates to the correct values, and calling grad(g) raises no errors. However, when I evaluate grad(g) on an input, I get an AssertionError with no message. Perhaps I'm trying to use functions that are not automatically differentiable?

My code is:

import jax.numpy as np
from jax import grad
from jax.scipy.linalg import expm
from jax.config import config
config.update("jax_enable_x64", True)

# define some complex matrices
X = -1j*np.array([[0, 1], [1, 0]], dtype=complex)
Y = -1j*np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = -1j*np.array([[1.,0],[0,-1]], dtype=complex)

# ##############
# construct function
# ##############

# computes hilbert-schmidt inner product between matrices
def hs_ip(x,y):
    return (x.conj().transpose() @ y).trace()

# computes |x|**2/4, for x = hs_ip(U,x), with U being a particular matrix defined below
U = expm((X+Y)/np.sqrt(2))
def f(x):
    y = hs_ip(U,x)
    return (y*y.conj()).real/4

# the function to be differentiated (takes in real array, returns real scalar)
def g(a):
    return f(expm(a[0]*X + a[1]*Y + a[2]*Z))

Calling g returns correct output (expected output is 1):

print(g(np.array([1,1,0], dtype=float)/np.sqrt(2)))
1.0000000000000004

Attempting to call grad(g) (expected output is 0):

grad(g)( np.array([1,1,0], dtype=float)/np.sqrt(2) )

throws error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-14-a0a3eafaec5b> in <module>
----> 1 grad(g)( np.array([1,1,0], dtype=float)/np.sqrt(2) )

~/anaconda3/envs/QiskitDev/lib/python3.7/site-packages/jax/api.py in grad_f(*args, **kwargs)
    360   @wraps(fun, docstr=docstr, argnums=argnums)
    361   def grad_f(*args, **kwargs):
--> 362     _, g = value_and_grad_f(*args, **kwargs)
    363     return g
    364 

~/anaconda3/envs/QiskitDev/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    416     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    417     if not has_aux:
--> 418       ans, vjp_py = _vjp(f_partial, *dyn_args)
    419     else:
    420       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/anaconda3/envs/QiskitDev/lib/python3.7/site-packages/jax/api.py in _vjp(fun, *primals, **kwargs)
   1334   if not has_aux:
   1335     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1336     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1337     out_tree = out_tree()
   1338   else:

~/anaconda3/envs/QiskitDev/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/anaconda3/envs/QiskitDev/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     96   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     97   aval_primals, const_primals = unzip2(pval_primals)
---> 98   assert all(aval_primal is None for aval_primal in aval_primals)
     99   if not has_aux:
    100     return const_primals, pval_tangents, jaxpr, consts

AssertionError: 

Note that I have not omitted anything; this is the full output (the AssertionError carries no message).

Thank you!

@mattjj
Collaborator

mattjj commented Mar 26, 2020

Thanks for raising this! This is a perfect place to post :)

I think expm differentiation is a WIP (see #2062, #1635, #1940) and unfortunately its current state raises this error. The place to track progress is #2062 but that's stalled out, probably because we JAX devs have dropped the ball on it.

Could you check out the code in #2062 and see if your code works there?

@mattjj mattjj added the bug Something isn't working label Mar 26, 2020
@DanPuzzuoli
Contributor Author

Thanks for the response!

I've checked out the code from #2062. Not sure if I'm making a mistake installing things; I've checked out the repo and run:

pip install jaxlib
pip install -e .

while in the folder where I checked out the code.

Now when I do:

X = -1j*np.array([[0, 1], [1, 0]], dtype=complex)
Y = -1j*np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = -1j*np.array([[1.,0],[0,-1]], dtype=complex)

I get the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-903db982c4e5> in <module>
----> 1 X = -1j*np.array([[0, 1], [1, 0]], dtype=complex)
      2 Y = -1j*np.array([[0, -1j], [1j, 0]], dtype=complex)
      3 Z = -1j*np.array([[1.,0],[0,-1]], dtype=complex)

~/Documents/projects/jax/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   1839   elif isinstance(object, (list, tuple)):
   1840     if object:
-> 1841       out = stack([array(elt, dtype=dtype) for elt in object])
   1842     else:
   1843       out = onp.array([], dtype or float_)

~/Documents/projects/jax/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1839   elif isinstance(object, (list, tuple)):
   1840     if object:
-> 1841       out = stack([array(elt, dtype=dtype) for elt in object])
   1842     else:
   1843       out = onp.array([], dtype or float_)

~/Documents/projects/jax/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   1839   elif isinstance(object, (list, tuple)):
   1840     if object:
-> 1841       out = stack([array(elt, dtype=dtype) for elt in object])
   1842     else:
   1843       out = onp.array([], dtype or float_)

~/Documents/projects/jax/jax/numpy/lax_numpy.py in <listcomp>(.0)
   1839   elif isinstance(object, (list, tuple)):
   1840     if object:
-> 1841       out = stack([array(elt, dtype=dtype) for elt in object])
   1842     else:
   1843       out = onp.array([], dtype or float_)

~/Documents/projects/jax/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   1831       out = device_put(object)
   1832   elif isscalar(object):
-> 1833     out = lax.reshape(object, ())
   1834     if dtype and _dtype(out) != dtypes.canonicalize_dtype(dtype):
   1835       out = lax.convert_element_type(out, dtype)

~/Documents/projects/jax/jax/lax/lax.py in reshape(operand, new_sizes, dimensions)
    638         operand, new_sizes=new_sizes,
    639         dimensions=None if same_dims else tuple(dimensions),
--> 640         old_sizes=onp.shape(operand))
    641 
    642 def pad(operand, padding_value, padding_config):

~/Documents/projects/jax/jax/core.py in bind(self, *args, **kwargs)
    177     top_trace = find_top_trace(args)
    178     if top_trace is None:
--> 179       return self.impl(*args, **kwargs)
    180 
    181     tracers = map(top_trace.full_raise, args)

~/Documents/projects/jax/jax/lax/lax.py in _reshape_impl(operand, new_sizes, dimensions, old_sizes)
   2562   else:
   2563     return xla.apply_primitive(reshape_p, operand, new_sizes=new_sizes,
-> 2564                                dimensions=dimensions, old_sizes=old_sizes)
   2565 
   2566 def _is_singleton_reshape(old, new):

~/Documents/projects/jax/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    157   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
    158   compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
--> 159   return compiled_fun(*args)
    160 
    161 @cache()

~/Documents/projects/jax/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, backend, tuple_args, result_handler, *args)
    244   if tuple_args:
    245     input_bufs = [make_tuple(input_bufs, device, backend)]
--> 246   out_buf = compiled.Execute(input_bufs)
    247   if FLAGS.jax_debug_nans:
    248     check_nans(prim, out_buf.destructure() if prim.multiple_results else out_buf)

TypeError: Execute(): incompatible function arguments. The following argument types are supported:
    1. (self: jaxlib.xla_extension.LocalExecutable, arguments: Span[jaxlib.xla_extension.PyLocalBuffer], tuple_arguments: bool) -> StatusOr[List[jaxlib.xla_extension.PyLocalBuffer]]

Invoked with: <jaxlib.xla_extension.LocalExecutable object at 0x7f86a8049970>, [<jaxlib.xla_extension.PyLocalBuffer object at 0x7f86f9011b70>]

@hawkinsp
Collaborator

What version of jax and jaxlib do you have installed? If I had to guess, your jaxlib isn't the latest. Can you verify it's version 0.1.42?
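A quick way to confirm which versions Python actually resolves (a generic check, not specific to this thread) is to print both packages' version strings, since a stale jaxlib can survive in an environment even after jax itself is reinstalled:

```python
import jax
import jaxlib

# Print the versions Python actually imports; a mismatch between
# these and what pip reports usually means the wrong environment
# (or a stale install) is being picked up.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
```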

@mattjj
Collaborator

mattjj commented Mar 26, 2020

That's my guess too. pip install --upgrade jaxlib might be better, since without --upgrade pip will just keep the old version installed.

@DanPuzzuoli
Contributor Author

Yep, I have the latest one. I just tried creating a new conda environment and installing, but I'm getting the same error. The steps I took were:

  • created and activated new conda env
  • Commands (in the folder with the branch from Issue1635 expm frechet #2062):
    conda install pip
    pip install --upgrade jaxlib
    pip install -e .
    
  • I then run the code in a jupyter notebook

When jaxlib installs, it reports version 0.1.42 (also verified with jaxlib.__version__ in the Jupyter notebook). Not sure if I'm making a naive mistake somewhere.

@DanPuzzuoli
Contributor Author

I'll close this for now, as I believe I can work around this. Thanks!
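One possible workaround, sketched here for readers hitting the same error (not part of the original thread's code, and it assumes the argument stays anti-Hermitian): since every generator in this example is anti-Hermitian, the matrix exponential can be built from jnp.linalg.eigh, which does have a differentiation rule, instead of jax.scipy.linalg.expm.

```python
import jax
import jax.numpy as jnp
from jax import grad

jax.config.update("jax_enable_x64", True)

# Anti-Hermitian generators, as in the original snippet.
X = -1j * jnp.array([[0, 1], [1, 0]], dtype=complex)
Y = -1j * jnp.array([[0, -1j], [1j, 0]], dtype=complex)
Z = -1j * jnp.array([[1., 0], [0, -1]], dtype=complex)

def expm_ah(A):
    # For anti-Hermitian A, the matrix iA is Hermitian, so eigh applies:
    # iA = V diag(w) V^H  =>  expm(A) = V diag(exp(-1j * w)) V^H.
    w, V = jnp.linalg.eigh(1j * A)
    return (V * jnp.exp(-1j * w)) @ V.conj().T

U = expm_ah((X + Y) / jnp.sqrt(2))

def hs_ip(x, y):
    # Hilbert-Schmidt inner product.
    return (x.conj().T @ y).trace()

def f(x):
    y = hs_ip(U, x)
    return (y * y.conj()).real / 4

def g(a):
    return f(expm_ah(a[0] * X + a[1] * Y + a[2] * Z))

a0 = jnp.array([1, 1, 0], dtype=float) / jnp.sqrt(2)
print(g(a0))        # close to 1
print(grad(g)(a0))  # close to [0, 0, 0]
```

This relies on eigh having a differentiation rule for complex Hermitian inputs (and on the eigenvalues being non-degenerate at the point of differentiation), so it is a sketch rather than a general replacement for expm.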
