# A case in which `jax.jit` changes results, prevents backprop, and inhibits optimisations (#9298)
According to my very shallow understanding of JAX, the code in both examples you gave is working as intended™. One should notice [...]. I think I've seen discussions about introducing a [...].
If this is intended, then it's an "as intended" that only has downsides: having a sub-jit is only ever a Bad Thing™. We still have to re-trace, but now we also have these casts-to-tracers. It would make more sense to either (a) fully disable the sub-jits (and thus improve compile-time evaluation), or (b) fully enable the sub-jits (and then avoid re-tracing, re-compiling, etc.).
Thanks for the excellent report as always, @patrick-kidger! Indeed this is working as intended. Yes, it has downsides. But the upsides outweigh them. Let me explain.
Just to keep our terminology straight: I think you're referring to specializing on constants at tracing/staging time. I'd refer to compile-time as the time when we run a compiler like XLA. In particular, in your first example the literal [...].

Interestingly, before #3370, JAX had the behavior you're asking for! But we put in a lot of work to change it. That's because, as described on that PR in more detail, JAX was doing too much specialization at trace time. It specialized (trace-time-evaluated) everything it could! But that led to bad behavior, like [...].

How is one to know what's worth constant-folding and what's not? It's a tricky compiler question (unless guided explicitly by an expert user; more on that later). We don't like to take on tricky compiler questions; that's what we rely on XLA for. So we stage out as much as possible to XLA and let it figure out what to constant-fold. A downside of that choice is that we can't do specialization at trace time.
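To make the staging-out point concrete, here is a small illustration of my own (not from the thread): even an "obviously" simplifiable operation on a traced value is staged out into the jaxpr verbatim rather than specialized away at trace time; any folding is left to XLA.

```python
import jax

def f(x):
    y = x * 0.0   # an "obviously" simplifiable operation...
    return y + 1.0

# ...yet the jaxpr still contains the mul and the add: JAX stages the
# operations out as-is and leaves constant folding to the XLA compiler.
print(jax.make_jaxpr(f)(2.0))
```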
Unfortunately this has costs too: it means we'd always be inlining all functions. That'd basically make our planned compile time improvements impossible. Analogously to constant folding, we'd rather give a proper compiler all the information it needs to decide whether to inline a function call. In particular, that means staging out functions rather than inlining them ourselves.
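A sketch (mine, not from the thread) of what "staging out functions rather than inlining them" looks like: a jit-decorated callee shows up in the caller's jaxpr as a single call equation (the primitive name varies across JAX versions, e.g. `pjit` or `xla_call`) rather than being flattened into its constituent operations.

```python
import jax

@jax.jit
def inner(x):
    return x + 1.0

def outer(x):
    # `inner` is staged out as one call equation in the jaxpr, not
    # inlined, so a later compiler can decide whether to inline it.
    return inner(x) * 2.0

print(jax.make_jaxpr(outer)(1.0))
```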
One last narrow technical point, before moving on to talk about solutions: sub-jit functions aren't different from any other jit function here. That is, this fails too, with no sub-jits in sight:

```python
import jax
import jax.lax as lax

@jax.jit
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

# @jax.jit  <-- commented out!
def g():
    return jax.grad(f)(0.0, 0)

g()  # still errors!
```

I think the behavior of [...].

In summary, I'd say that a big upside to the [...]. But even if it's the best starting point, that doesn't mean we can't do more to enable expert library authors. For example, if you'd like your library to be able to specialize on values known at trace time, we'd love to help give you tools to make that possible! (But they might be opt-in tools using not-promised-to-be-stable APIs rather than changes to e.g. `jax.jit` itself.) If you're sure you want a [...].

Separately, this isn't quite what you asked for, but there happens to already be an easy way to give the caller the power to inline jits:

```python
import jax
import jax.lax as lax
from jax.experimental.callback import callback_transform

def inline_inner_jits_at_trace_time(f):
    def trivial(prim, vals, params):
        return prim.bind(*vals, **params)
    return callback_transform(f, trivial, strip_calls=True)

@jax.jit
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    f_ = inline_inner_jits_at_trace_time(f)
    return jax.grad(f_)(0.0, 0)

g()
```

Since the caller knows they're applying their own [...].
One more thing to push back on: if this were the case, we wouldn't put [...].
I couldn't remember why I made [...].

While we've got our fingers crossed, though, on #9342 this works:

```python
from functools import partial
import jax
import jax.lax as lax

@partial(jax.jit, inline=True)
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    return jax.grad(f)(0.0, 0)

g()
```

It's doing the thing that I believe you expected from the start: if you run [...].

But this implementation still has the weird behavior described above: if you remove the `@jax.jit` on `g`, it still errors.
@mattjj I think code like this just shouldn't work (modulo the ongoing dynamic shape work, of course) without proper annotation w.r.t. static arguments. They fail now because [...]. Currently, the semantics of [...]. And actually the reason [...].
Thanks for the (very long!) response. I agree with most of what you've said, so to keep things brief I'll just highlight a few points I'm curious about!
Ah, so my suggestion was predicated on the understanding that this was already the case, and that everything was inlined during tracing anyway. From what you say (and looking at #9181 -- thanks for the link) it sounds like this is indeed currently the case, but that it will soon be changing? And presumably [...]. What does [...]?
So my first (easy!) suggestion would be to offer a public [...].
It sounds like this is what #9342 is doing; agreed that this would also be a great tool.
I don't understand what you mean by this. I'm aware that everything the [...].
It sounds like we're stuck with some mystery either way... I think I'd rather have things mysteriously working than mysteriously breaking! And finally:
I'm pretty sure JAX checks all of these boxes:
Dear Sir, you have built a compiler. ;)
I don't think JAX has a parser or a code generator though! (From Wikipedia: "into a form (e.g., machine code) that can be readily executed by a machine.")
@soraros hrm, on what basis are you saying something should or shouldn't fail, though? I don't think "type-level variable" is well-defined in Python or JAX. I think we're free to make up our own rules so long as there aren't contradictions. Are you referring to some problem other than the "grad-of-jit doesn't work, yet jit-of-grad-of-jit does" issue? I agree that's weird. But I think we're only talking about which programs raise errors here; that is, there aren't two valid (non-error) behaviors we're deciding between, as far as I know. So why not take the stance that we should just make as few programs raise errors as possible?
Thanks for the link! But I don't grok my explanation there... in particular, running the WrappedFun without an incremented sublevel seems fine now (it must've been [...]).
@mattjj I meant that all the errors we see in this thread are just more obscure versions of the famous [...] error. For example,

```python
def f(init, lower):
    def body_fun(_, x):
        return x + 1
    return lax.fori_loop(lower, 1, body_fun, init)
```

should really be typed as

```haskell
f :: (..., Concretizable b) => a -> b -> a
```

The only way for us to mimic this in python/JAX is to annotate it with [...].

And here is a bug(?) I found regarding #9342. Consider the following code:

```python
from functools import partial
from jax import jit

# @jit
# @partial(jit, inline=True)
def f(init, lower):
    if lower < 1:
        return 0.
    else:
        return g(init, lower - 1)

# @jit
# @partial(jit, inline=True)
def g(init, lower):
    if lower < 1:
        return 0.
    else:
        return f(init, lower - 1)

# @jit
def h():
    return f(0., 2)

h()
```
@soraros I don't think I agree. In terms of type annotations, the point is that we want to express overloads of a function. Using [...]. (Note that there's no real way to express this in Haskell-like notation, since Haskell has sum types instead of union types.) As it stands there is no way to wrap [...].
I agree, yet I think traced numerical arguments are just less restrictive than static ones. So

```haskell
fori_loop :: (Integral i, Concretizable i, Integral j) =>
             i -> i -> ((j, a) -> a) -> a -> a
```

assuming

```haskell
instance Integral Int where ...
instance Integral TracedInt where ...

instance Concretizable Int where
  concretize = id
```

Maybe my understanding is fundamentally flawed, but could you give an example where [...]?
I disagree with this. Traced arguments have advantages (no recompilation) and disadvantages (no backpropagation through an iteration count of unknown length); static arguments have advantages (backprop through iterations; more specialised/efficient compiled code) and disadvantages (recompilation).
This doesn't take into account the possibility of traced [...].
Consider running an RNN over a sequence whose length isn't known until runtime. As a more advanced example in the same spirit: consider writing an ODE solver. Sometimes you take fixed steps, and the number of steps is known in advance. Sometimes you take adaptive steps, and the number of steps is not known until run time. It can be desirable to produce specialised code for each case, whilst handling both cases through a single interface, much like [...].
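The single-interface/two-specialisations pattern being described can be sketched as follows (my own illustration; `repeat_apply` is a hypothetical helper, not a JAX or thread API): dispatch on whether the iteration count is a concrete Python value or only known at run time.

```python
import jax
import jax.lax as lax

def repeat_apply(f, x, n):
    # Hypothetical helper: apply f to x, n times.
    if isinstance(n, int):
        # n is a concrete Python int at trace time: unroll, which
        # supports reverse-mode autodiff and lets the compiler
        # specialise on the trip count.
        for _ in range(n):
            x = f(x)
        return x
    else:
        # n is only known at run time (a tracer): use a while_loop,
        # which works under jit but is not reverse-mode differentiable.
        def cond(carry):
            i, _ = carry
            return i < n
        def body(carry):
            i, y = carry
            return i + 1, f(y)
        return lax.while_loop(cond, body, (0, x))[1]

step = lambda y: y + 1.0
print(repeat_apply(step, 0.0, 3))                        # unrolled
print(jax.jit(lambda n: repeat_apply(step, 0.0, n))(3))  # while_loop
```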
@patrick-kidger Ahh, my bad, sorry for being slow on your points. Just realized that sometimes [...].
@soraros whoa, any idea why your example (5) diverges? I ran it and I couldn't immediately tell what was going on. (Thanks for finding the error message bug too!) I should've mentioned earlier: [...]. Another thing I should've mentioned earlier: [...].
@mattjj If you disable [...]:

```python
from functools import partial
import jax
from jax import jit

jax.config.update('jax_platform_name', 'cpu')

@partial(jit, inline=True)
def f(rec):
    if rec:
        return f(False)
    else:
        return 0

@jit
def h():
    return f(True)

h()
```

Somehow the inner call, namely `f(False)`, deadlocks.
Thanks for sharing that finding. There's a recently-uncovered bug in the C++ jit dispatch path which might explain the deadlock. I've lost the plot on this issue thread. I could page it back in, but I'm hoping that you folks can help me avoid that... @patrick-kidger and @soraros, is the conclusion here that we should just close the issue? Or land #9342? Or figure out something else to do?
I think land #9342 for certain. Once that's in, then I don't think there's ever a reason to use [...].
The summary is that I've found an (edge) case in which:

- adding a `jax.jit` decorator can change the functional behaviour of the program;
- adding a `jax.jit` decorator can prohibit reverse-mode autodifferentiation;
- adding a `jax.jit` decorator can inhibit the use of compile-time optimisations.

In all cases it is due to the same single root cause: a nested `jax.jit` unnecessarily converts concrete values to tracers.

First, a MWE of the root cause. This program prints `Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/2)>`: the `1` is unnecessarily converted into a tracer.

At first glance that probably seems reasonable. Why is this undesirable? Consider the following case.
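The MWE snippet itself did not survive this copy of the issue; the following is a hypothetical reconstruction (the names `inner` and `outer` are mine) in which a concrete literal passed into a sub-jit is printed as a tracer:

```python
import jax

@jax.jit
def inner(x):
    # x arrives as a tracer, even though the caller passed the literal 1
    print(x)
    return x

@jax.jit
def outer():
    # at trace time of `outer`, the literal 1 is a concrete Python int;
    # the nested jit nonetheless converts it to a tracer
    return inner(1)

outer()
```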
### adding a `jax.jit` can change the functional behaviour of the program

If `f` is decorated with `jax.jit` then the above program crashes. If `f` is undecorated then the program runs.

### adding a `jax.jit` can inhibit reverse-mode autodifferentiation

(Same as above: if `f` is decorated with `jax.jit` then the above program crashes; if `f` is undecorated then the program runs.)

### adding a `jax.jit` can inhibit the use of compile-time optimisations

As a library author interested in optimising compile times and run times, I'd like to specialise behaviour based on compile-time values. In particular, I have a case in which, if a value is known at compile time, I can produce code that is efficient to compile. If it is only known at run time, the extra generality requires a more complicated program -- one that produces the same functional results -- but is several orders of magnitude slower to compile.
(So morally speaking, something a little similar to `lax.fori_loop` in native JAX.)

Resolution: I think the resolution should be that within the dynamic context of a `jax.jit` decorator, all other `jax.jit` decorators should be disabled.

C.f. #7155: sub-jit-functions are retraced within the dynamic context of a `jax.jit` decorator.

Overall, sub-jit-functions currently enter a weird state of being only "partially disabled" within the dynamic context of a `jax.jit` decorator: they are retraced, but values are still unnecessarily promoted to tracers.