If you see this error, please let us know by opening an issue #2346

Closed
lukasheinrich opened this issue Mar 3, 2020 · 8 comments
@lukasheinrich

In our code over at https://github.com/pyhf/neos we run into an issue when jit'ing functions, which reports:

AssertionError: If you see this error, please let us know by opening an issue at
https://github.com/google/jax/issues
since we thought this was unreachable!

We did manage to reach it; what would be the best way to report more detail?

cc @phinate

@shoyer
Collaborator

shoyer commented Mar 4, 2020

This issue would be fine! But we’ll definitely need more detail, ideally with a minimal example to reproduce the issue and a stack trace.

@phinate

phinate commented Mar 4, 2020

Hi @shoyer, I've done my best to rip the code from our neos library to reproduce the error.

The context of the error comes from trying to use autodiff to differentiate through a minimisation by gradient descent. The current way we do that in practice makes use of this implementation of Christianson's two-phase solver for fixed-point differentiation.

We were trying to speed things up by calling jax.jit, but ran into this interesting error in the process. It arises when we use jax.jit after wrapping the minimiser in a function to remove the dependence on the init values required by the fax implementation, but it disappears if we also call jax.jit on the minimiser inside that wrapping function.

As one might expect, the error does not depend on the details of the problem we were trying to solve, so the code example below just uses a dummy likelihood that returns 1.

import jax
from jax.experimental import optimizers
from fax.implicit import twophase


# doesn't matter what we return!
def log_likelihood(pars):
    return jax.numpy.ones(1,)[0]


def get_fit(
    default_rtol=1e-10,
    default_atol=1e-10,
    default_max_iter=int(1e7),
    learning_rate=0.01,  # note: currently unused, adam below is hard-coded to 1e-6
):

    adam_init, adam_update, adam_get_params = optimizers.adam(1e-6)

    def global_bestfit_minimized(ignored_param):

        def bestfit_via_grad_descent(i, param):  # one gradient-descent step
            g = jax.grad(log_likelihood)(param)
            param = adam_get_params(adam_update(i, g, adam_init(param)))
            return param

        return bestfit_via_grad_descent

    global_solve = twophase.two_phase_solver(
        param_func=global_bestfit_minimized,
        default_rtol=default_rtol,
        default_atol=default_atol,
        default_max_iter=default_max_iter,
    )

    def global_fit(init, ignored_param):
        solve = global_solve(init, ignored_param)
        return solve.value

    return global_fit


def do_fit(ignored_param):
    fit = get_fit()

    # Uncommenting this line makes the error disappear
    # fit = jax.jit(fit)

    return fit(1., ignored_param)


ignored_param = 1
jax.jit(do_fit)(ignored_param)  # throws the error

I have a feeling this probably concerns details of the fax library, so I imagine the author has some valuable input (though I'm not able to tag them in this issue).

Hope this helps! :)

@phinate

phinate commented Mar 4, 2020

And here's the long stack trace:

--------------------------------------------------------------------------
AssertionError                           Traceback (most recent call last)
<ipython-input-1-2ff1168f5a07> in <module>
     49 
     50 ignored_param = 1
---> 51 jax.jit(do_fit)(ignored_param) # throws the error

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/api.py in f_jitted(*args, **kwargs)
    144     flat_fun, out_tree = flatten_fun(f, in_tree)
    145     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 146                        name=flat_fun.__name__)
    147     return tree_unflatten(out_tree(), out)
    148 

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/core.py in call_bind(primitive, f, *args, **params)
    640   if top_trace is None:
    641     with new_sublevel():
--> 642       outs = primitive.impl(f, *args, **params)
    643   else:
    644     tracers = map(top_trace.full_raise, args)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
    446 
    447 def _xla_call_impl(fun, *args, device, backend, name):
--> 448   compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
    449   try:
    450     return compiled_fun(*args)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/linear_util.py in memoized_fun(fun, *args)
    218       fun.populate_stores(stores)
    219     else:
--> 220       ans = call(fun, *args)
    221       cache[key] = (ans, fun.stores)
    222     return ans

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
    463   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    464   with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 465     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    466     assert not env  # no subtraces here
    467     del master, env

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    147     gen = None
    148 
--> 149     ans = self.f(*args, **dict(self.params, **kwargs))
    150     del args
    151     while stack:

<ipython-input-1-2ff1168f5a07> in do_fit(ignored_param)
     46     #fit = jax.jit(fit)
     47 
---> 48     return fit(1.,ignored_param)
     49 
     50 ignored_param = 1

<ipython-input-1-2ff1168f5a07> in global_fit(init, ignored_param)
     35 
     36     def global_fit(init, ignored_param):
---> 37         solve = global_solve(init, ignored_param)
     38         return solve.value
     39 

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/api.py in __call__(self, *args)
   1440     outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
   1441                           in_tree=in_tree, out_tree=out_tree(),
-> 1442                           num_consts=len(consts))
   1443     return tree_unflatten(out_tree(), outs)
   1444 

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/core.py in bind(self, *args, **kwargs)
    180 
    181     tracers = map(top_trace.full_raise, args)
--> 182     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    183     if self.multiple_results:
    184       return map(full_lower, out_tracer)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
     96       return custom_partial_eval_rules[primitive](self, *tracers, **params)
     97     else:
---> 98       return self.default_process_primitive(primitive, tracers, params)
     99 
    100   def default_process_primitive(self, primitive, tracers, params):

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/partial_eval.py in default_process_primitive(self, primitive, tracers, params)
    104     tracers = map(self.instantiate_const, tracers)
    105     avals = [t.aval for t in tracers]
--> 106     out_aval = primitive.abstract_eval(*avals, **params)
    107     if primitive.multiple_results:
    108       out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/api.py in fun_abstract_eval(*avals, **params)
   1503 
   1504   def fun_abstract_eval(*avals, **params):
-> 1505     return pe.abstract_eval_fun(fun_impl, *avals, **params)
   1506   fun_p.def_abstract_eval(fun_abstract_eval)
   1507 

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/partial_eval.py in abstract_eval_fun(fun, *avals, **params)
    271   pvals_in = [PartialVal((a, unit)) for a in avals]
    272   _, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
--> 273                                   instantiate=True)
    274   avals_out, _ = unzip2(pvals_out)
    275   for aval_out in avals_out:

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls)
    352   with new_master(trace_type) as master:
    353     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 354     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    355     assert not env
    356     del master

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    147     gen = None
    148 
--> 149     ans = self.f(*args, **dict(self.params, **kwargs))
    150     del args
    151     while stack:

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/api.py in fun_impl(*args, **params)
   1491   def fun_impl(*args, **params):
   1492     consts, args = split_list(args, [params['num_consts']])
-> 1493     return core.eval_jaxpr(params['jaxpr'], consts, *args)
   1494   fun_p.def_impl(fun_impl)
   1495 

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/core.py in eval_jaxpr(jaxpr, consts, *args)
    249     else:
    250       subfuns = []
--> 251     ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
    252     if eqn.primitive.multiple_results:
    253       map(write, eqn.outvars, ans)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/core.py in bind(self, *args, **kwargs)
    180 
    181     tracers = map(top_trace.full_raise, args)
--> 182     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    183     if self.multiple_results:
    184       return map(full_lower, out_tracer)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
     96       return custom_partial_eval_rules[primitive](self, *tracers, **params)
     97     else:
---> 98       return self.default_process_primitive(primitive, tracers, params)
     99 
    100   def default_process_primitive(self, primitive, tracers, params):

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/interpreters/partial_eval.py in default_process_primitive(self, primitive, tracers, params)
    104     tracers = map(self.instantiate_const, tracers)
    105     avals = [t.aval for t in tracers]
--> 106     out_aval = primitive.abstract_eval(*avals, **params)
    107     if primitive.multiple_results:
    108       out_tracers = [JaxprTracer(self, PartialVal((aval, unit)), None)

~/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/lax/lax.py in standard_abstract_eval(prim, shape_rule, dtype_rule, *args, **kwargs)
   1518            "https://github.com/google/jax/issues \n"
   1519            "since we thought this was unreachable!")
-> 1520     assert pe._thread_local_state.remat, msg
   1521     return ConcreteArray(prim.impl(*[x.val for x in args], **kwargs))
   1522   elif least_specialized is ShapedArray:

AssertionError: If you see this error, please let us know by opening an issue at
https://github.com/google/jax/issues
since we thought this was unreachable!

@shoyer
Collaborator

shoyer commented Mar 5, 2020

It looks like this code involves differentiating something defined with custom_gradient, so I'm guessing this is the same issue as #1875.

We are about to overhaul the custom gradients machinery in #2026, and I'm optimistic that it will fix this, too.

@gehring
Contributor

gehring commented Mar 5, 2020

Very possibly unrelated, but we ran into this error while writing fax. If I recall correctly, it was caused by a user error (on my part) which, once fixed, made the error disappear, so we never thought of reporting it.

We haven't really encountered it since, so I'm not sure what would cause this without diving deep into jax. If there is a simple workaround, I'm happy to update fax. Feel free to ping me!

@mattjj
Collaborator

mattjj commented Mar 22, 2020

Sorry for the slow response.

The only place we've ever seen this error arise is in using jax.custom_transforms, and it's one of the reasons we are about to rip it out in #2026 (which is about to land).

I think #2026 fixes this: in fact, coincidentally (because I hadn't looked at this issue in detail until just now), the tutorial notebook added by that PR includes a reverse-mode differentiable fixed_point routine as an example. That sounds very similar to the routine in fax, which is why I'm optimistic #2026 will help! (The implementation in the tutorial is based on the one I added to Autograd at one point.)
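
To give a rough flavour of that approach, here is a minimal sketch (not the notebook's exact code) of a reverse-mode differentiable fixed point built on the new jax.custom_vjp API from #2026; the names fixed_point, fixed_point_fwd, and fixed_point_rev and the convergence tolerance are illustrative choices:

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# A fixed-point solver x* = f(a, x*) whose reverse-mode rule solves the
# adjoint fixed point instead of unrolling the forward iteration.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    def cond(carry):
        x_prev, x = carry
        return jnp.max(jnp.abs(x_prev - x)) > 1e-6

    def body(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = lax.while_loop(cond, body, (x_guess, f(a, x_guess)))
    return x_star


def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    return x_star, (a, x_star)


def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    # Solve w = x_star_bar + (df/dx)^T w with the same fixed-point iteration.
    w = fixed_point(lambda _, w: x_star_bar + vjp_x(w)[0], None, x_star_bar)
    a_bar, = vjp_a(w)
    return a_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)

# Example: sqrt(a) as the fixed point of Newton's update, differentiated in a.
f = lambda a, x: 0.5 * (x + a / x)
print(jax.grad(lambda a: fixed_point(f, a, a))(2.0))  # ~ 1 / (2 * sqrt(2)) ≈ 0.354

The key point is that the backward pass solves the adjoint fixed point with the same iteration, rather than differentiating through the unrolled while_loop.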

@mattjj
Collaborator

mattjj commented Mar 22, 2020

I just merged #2026, so I believe this issue with custom_transforms is fixed (almost tautologically). Please take a look at the tutorial notebook for the new way to define custom JVPs / VJPs.

I'm going to close this specific issue but please open new ones as questions about the new API arise. I can also help update old custom_transforms code if needed.
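
For anyone reading along, the new-style definition looks roughly like this; a minimal sketch using jax.custom_jvp (log1pexp is just an illustrative function, see the tutorial notebook for the full story, including jax.custom_vjp):

import jax
import jax.numpy as jnp


# The old custom_transforms-style rules are replaced by jax.custom_jvp / jax.custom_vjp.
# Here: a hand-written JVP rule for log(1 + exp(x)).
@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    y = log1pexp(x)
    # d/dx log(1 + exp(x)) = sigmoid(x), written in a numerically stable form.
    y_dot = x_dot / (1.0 + jnp.exp(-x))
    return y, y_dot


print(jax.grad(log1pexp)(3.0))  # sigmoid(3.0) ≈ 0.9526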

@mattjj mattjj closed this as completed Mar 22, 2020
@gehring
Contributor

gehring commented Mar 22, 2020

That sounds very similar to the routine in fax [...]

That is correct! That is essentially how it is done in fax.implicit.twophase, except we use an arguably more bloated API...

I'm super excited to see that it supports grad(grad(...))! I tried to support that in fax but ran into lots of issues with the old custom_gradient. It's good to see that it's this easy to pull off now. Looks like I have some refactoring to do! Great job with the redesign @mattjj!
