
A case in which jax.jit changes results, prevents backprop, and inhibits optimisations #9298

Open
patrick-kidger opened this issue Jan 24, 2022 · 19 comments · May be fixed by #9342
Labels
bug Something isn't working

patrick-kidger commented Jan 24, 2022

The summary is that I've found an (edge) case in which:

  • adding a jax.jit decorator can change the functional behaviour of the program;
  • adding a jax.jit decorator can prohibit reverse-mode autodifferentiation;
  • adding a jax.jit decorator can inhibit the use of compile-time optimisations.

In all cases it is due to the same single root cause: a nested jax.jit unnecessarily converts concrete values to tracers.


First a MWE of the root cause:

import jax

@jax.jit
def f(x):
    print(x)

@jax.jit
def g():
    f(1)

g()

This program prints Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/2)>: the 1 is unnecessarily converted into a tracer.
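
For contrast, a minimal sketch (the only change is removing the inner jax.jit): if f is left undecorated, the literal stays a concrete Python int during the trace of g.

import jax

def f(x):  # no jax.jit here
    print(x)

@jax.jit
def g():
    f(1)

g()  # prints 1: the literal remains a concrete Python int at trace time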

At first glance that probably seems reasonable. Why is this undesirable? Consider the following case.

import jax
import jax.lax as lax

@jax.jit
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    return jax.grad(f)(0.0, 0)

g()

adding a jax.jit can change the functional behaviour of the program

If f is decorated with jax.jit then the above program crashes.
If f is undecorated then the program runs.

adding a jax.jit can inhibit reverse-mode autodifferentiation

(Same as above:
If f is decorated with jax.jit then the above program crashes.
If f is undecorated then the program runs.)

adding a jax.jit can inhibit the use of compile-time optimisations

As a library author interested in optimising compile and run times, I'd like to specialise behaviour based on compile-time values.

In particular I have a case in which if a value is known at compile time then I can produce code that is efficient to compile. If it is only known at run time then the extra generality requires a more complicated program -- that produces the same functional results -- but is several orders of magnitude slower to compile.

(So morally speaking something a little similar to lax.fori_loop in native JAX.)
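
To illustrate the kind of specialisation I mean, here is a minimal sketch (a hypothetical routine, not my actual library code): dispatch on whether an argument is still a concrete value at trace time, falling back to a general loop when it has been promoted to a tracer.

import jax
import jax.lax as lax
from jax.core import Tracer

def count_up(x, num_steps):
    # Hypothetical routine: specialise when num_steps is known at trace time.
    if isinstance(num_steps, Tracer):
        # Only known at run time: general (and slower-to-compile) path.
        return lax.fori_loop(0, num_steps, lambda _, y: y + 1, x)
    # Known at trace time: emit a small specialised program (here, unrolled).
    for _ in range(num_steps):
        x = x + 1
    return x

Under a nested jax.jit, num_steps arrives as a tracer even when the caller passed a Python literal, so the specialised path is never taken.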


Resolution: I think the resolution should be that, within the dynamic context of a jax.jit decorator, all other jax.jit decorators should be disabled.


Cf. #7155: sub-jit functions are retraced within the dynamic context of a jax.jit decorator.

Overall, sub-jit functions currently enter a weird state of being only "partially disabled" within the dynamic context of a jax.jit decorator: they are retraced, but values are still unnecessarily promoted to tracers.

patrick-kidger added the bug label on Jan 24, 2022

soraros commented Jan 24, 2022

According to my very shallow understanding of JAX, the code in both examples you gave is working as intended™. One should note that jax.jit also serves as an API tool, as it is the only way (that I know of) to mark which arguments of a function are static. So a value passed to a jitted function, which is treated as non-static unless explicitly marked otherwise, should be wrapped in a tracer.

I think I've seen discussions about introducing a static wrapper so one can make the decision about which argument is static at the call site, but I failed to find that thread.

patrick-kidger (Collaborator, Author) commented:

If this is intended then it's an "as intended" that only has downsides: having a sub-jit is only ever a Bad Thing™. We still have to re-trace, but now also have these cast-to-tracers. It would make more sense to either (a) fully disable the sub-jits (and thus improve compile-time evaluation), or (b) fully enable the sub-jits (and then avoid re-tracing, re-compiling etc.)

mattjj commented Jan 26, 2022

Thanks for the excellent report as always, @patrick-kidger !

Indeed this is working as intended. Yes it has downsides. But the upsides outweigh them. Let me explain.

As a library author interested in optimising compile and run times, I'd like to specialise behaviour based on compile-time values. In particular I have a case in which if a value is known at compile time then I can produce code that is efficient to compile.

Just to keep our terminology straight: I think you're referring to specializing on constants at tracing/staging time. I'd refer to compile-time as the time when we run a compiler like XLA. In particular, the literal 1 in your first example, and the literals 0.0 and 0 in your second, are all given as compile-time constants to XLA, so they are constants at what I'm proposing we refer to as compile-time. But they're not constants at JAX's trace/staging time.
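
As a quick illustration of that distinction (a sketch; the exact printed form varies across JAX versions): the literal shows up in the staged-out jaxpr, i.e. as a compile-time constant for XLA to see, even though the Python-level computation under jit happens on tracers.

import jax

def f(x):
    return x + 1

# The 1 is staged out as a literal in the jaxpr (a constant visible to XLA),
# while x + 1 itself is evaluated on tracers during JAX's trace.
print(jax.make_jaxpr(f)(0.0))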

Interestingly, before #3370, JAX had the behavior you're asking for! But we put in a lot of work to change it. That's because, as described on that PR in more detail, JAX was doing too much specialization at trace time. It specialized (trace-time-evaluated) everything it could! But that led to bad behavior, like random.normal(random.PRNGKey(0), (5000, 5000)) being specialized/evaluated into a big array constant during tracing.

How is one to know what's worth constant-folding and what's not? It's a tricky compiler question (unless guided explicitly by an expert user; more on that later). We don't like to take on tricky compiler questions; that's what we rely on XLA for. So we stage out as much as possible to XLA and let it figure out what to constant-fold.

A downside of that choice is that we can't do specialization at trace-time.

Resolution: I think the resolution should be that, within the dynamic context of a jax.jit decorator, all other jax.jit decorators should be disabled.

Unfortunately this has costs too: it means we'd always be inlining all functions. That'd basically make our planned compile time improvements impossible. Analogously to constant folding, we'd rather give a proper compiler all the information it needs to decide whether to inline a function call. In particular, that means staging out functions rather than inlining them ourselves.

Overall, sub-jit functions currently enter a weird state of being only "partially disabled" within the dynamic context of a jax.jit decorator: they are retraced, but values are still unnecessarily promoted to tracers.

One last narrow technical point, before moving on to talk about solutions: sub-jit functions aren't different from any other jit function here. That is, this fails too, with no sub-jits in sight:

import jax
import jax.lax as lax

@jax.jit
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

# @jax.jit <-- commented out!
def g():
    return jax.grad(f)(0.0, 0)

g()  # still errors!

I think the behavior of jit here is currently self-consistent, and if we were to "disable all sub-jits" that would actually make for complex behavior, where e.g. in the above example jax.grad(f)(0.0, 0) only works if there's an outer jit.


In summary, I'd say that a big upside to the "jit stages out everything" approach is that it's simple: compared to automatically deciding what values to constant fold at tracing/staging time and which not to (or which functions to trace-time-inline), it's simpler for us to implement and maintain, and it's simpler for users to predict jit's behavior. For that reason I think it's the best starting point.

But even if it's the best starting point, that doesn't mean we can't do more to enable expert library authors. For example, if you'd like your library to be able to specialize on values known at trace time, we'd love to help give you tools to make that possible! (But they might be opt-in tools using not-promised-to-be-stable APIs rather than changes to e.g. jit's default behavior.)

If you're sure you want a jit that disappears when there's an outer jit (and hence which will sometimes succeed only when there's an outer jit, as in your example), we can try to put something like that together. What do you think?

Separately, this isn't quite what you asked for, but there happens to already be an easy way to give the caller the power to inline jits:

import jax
import jax.lax as lax
from jax.experimental.callback import callback_transform

def inline_inner_jits_at_trace_time(f):
    def trivial(prim, vals, params):
        return prim.bind(*vals, **params)
    return callback_transform(f, trivial, strip_calls=True)

@jax.jit
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    f_ = inline_inner_jits_at_trace_time(f)
    return jax.grad(f_)(0.0, 0)

g()

Since the caller knows they're applying their own jit to g, they can use that transformation to trace-time-inline all jits in the application of f.

mattjj commented Jan 26, 2022

The inline option already on jit is almost what you want to use here, but for a technical reason I can't remember I wrote that to first trace to a jaxpr, so we're not really doing trace-time specialization for the called function.

mattjj commented Jan 26, 2022

If this is intended then it's an "as intended" that only has downsides: having a sub-jit is only ever a Bad Thing™.

One more thing to push back on: if this were the case, we wouldn't put jit on all our jax.numpy functions! (But we do, since that's what lets us provide a JIT-compiled NumPy, making eager execution faster. Other libraries use jit for the same purpose.)

mattjj added a commit to mattjj/jax that referenced this issue on Jan 26, 2022
mattjj linked a pull request (#9342) on Jan 26, 2022 that will close this issue

mattjj commented Jan 26, 2022

I couldn't remember why I made jit(inline=True) build a jaxpr first. So in #9342 I tried adjusting it, and the tests I've tried so far all seem to work... (Still, I've got a feeling we're going to see a test failure which will illuminate why I didn't take this approach the first time!)

While we've got our fingers crossed, though, on #9342 this works:

from functools import partial
import jax
import jax.lax as lax

@partial(jax.jit, inline=True)
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    return jax.grad(f)(0.0, 0)

g()

It's doing the thing that I believe you expected from the start: if you run print(jax.make_jaxpr(g)()), it'll be as if the inner jit just isn't there at all.

But this implementation still has the weird behavior described above: if you remove the jax.jit from g, then it fails! That is, as before, grad-of-jit fails on f, but jit-of-grad-of-jit mysteriously works!

soraros commented Jan 27, 2022

@partial(jax.jit, inline=True)
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    return jax.grad(f)(0.0, 0)

g()

@mattjj I think code like this just shouldn't (modulo the ongoing dynamic shape work, of course) work without proper annotation w.r.t. static arguments. It fails now because lower really is a type-level variable; making it traceable (which a naked @jit does) violates the static shape requirement. The proper 'fix' should be to use partial(jit, static_argnums=(1,)).

Currently, the semantics of jit is quite simple: every argument of the wrapped function is traced (wrapped in a Tracer, tagged, made part of the embedding, etc., whatever you call it) unless stated otherwise via static_argnums or static_argnames (in which case it's not traced and works at the meta/host-language level), and the inline keyword removes one layer of xla_call. I would vote against #9342 as it hurts such a simple model.

And actually the reason inline first traced the function into a jaxpr is explained in your very well written PR message here. ;)

patrick-kidger (Collaborator, Author) commented:

Thanks for the (very long!) response.

I agree with most of what you've said, so to keep things brief I'll just highlight a few points I'm curious about!

Unfortunately this has costs too: it means we'd always be inlining all functions. That'd basically make our planned compile time improvements impossible. Analogously to constant folding, we'd rather give a proper compiler all the information it needs to decide whether to inline a function call. In particular, that means staging out functions rather than inlining them ourselves.

Ah, so my suggestion was predicated on the understanding that this was already the case, and that everything was inlined during tracing anyway. From what you say (and looking at #9181, thanks for the link) it sounds like this is indeed currently the case -- but this will soon be changing? And presumably, jit(inline=True) will be the toggle that disables this new behaviour?

What does jit(inline=True) do at the moment? The docs are a bit sparse. I can see that one obtains slightly different jaxprs but my belief was that this didn't actually change anything, practically speaking.

But even if it's the best starting point, that doesn't mean we can't do more to enable expert library authors. For example, if you'd like your library to be able to specialize on values known at trace time, we'd love to help give you tools to make that possible! (But they might be opt-in tools using not-promised-to-be-stable APIs rather than changes to e.g. jit's default behavior.)

So my first (easy!) suggestion would be to offer a public jax.is_jit_enabled() function to determine whether or not JIT'ing is currently occurring. (I suppose one could probably already hack this via isinstance(jnp.array(1) + 1, jax.core.Tracer)...) Then it'd be possible to write a wrapper for a maybe-inner-jit'd function that pipes arguments through a jit-trace or jit-static as desired.
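
For concreteness, a sketch of that hack (a hypothetical helper, not an existing API):

import jax
import jax.numpy as jnp

def is_jit_enabled():
    # Hypothetical helper: detect an ambient jit trace by checking whether a
    # trivial computation on a constant gets staged out as a tracer.
    return isinstance(jnp.array(1) + 1, jax.core.Tracer)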

If you're sure you want a jit that disappears when there's an outer jit (and hence which will sometimes succeed only when there's an outer jit, as in your example), we can try to put something like that together. What do you think?

It sounds like this is what #9342 is doing; agreed that this would also be a great tool.

One more thing to push back on: if this were the case, we wouldn't put jit on all our jax.numpy functions! (But we do, since that's what lets us provide a JIT-compiled NumPy, making eager execution faster. Other libraries use jit for the same purpose.)

I don't understand what you mean by this. I'm aware that everything in the jnp namespace has a jax.jit wrapper for the purposes of eager execution, but I don't think anything I was suggesting would have impacted this, i.e. it would have been reasonable to just leave all of those in place.

But this implementation still has the weird behavior described above: if you remove the jax.jit from g, then it fails! That is, as before, grad-of-jit fails on f, but jit-of-grad-of-jit mysteriously works!

It sounds like we're stuck with some mystery either way... I think I'd rather have things mysteriously working than mysteriously breaking!


And finally:

How is one to know what's worth constant-folding and what's not? It's a tricky compiler question (unless guided explicitly by an expert user; more on that later). We don't like to take on tricky compiler questions; that's what we rely on XLA for. So we stage out as much as possible to XLA and let it figure out what to constant-fold.

I'm pretty sure JAX checks all of these boxes:

  • a parser,
  • an intermediate representation,
  • transformation passes,
  • and a code generator

Dear Sir, you have built a compiler.

;)

soraros commented Jan 27, 2022

@patrick-kidger inline removes one layer of xla_call, ref.

mattjj commented Jan 27, 2022

I'm pretty sure JAX checks all of these boxes:

I don't think JAX has a parser or a code generator though! (From Wikipedia: "into a form (e.g., machine code) that can be readily executed by a machine.")

I think code like this just shouldn't (modulo the ongoing dynamic shape work, of course) work without proper annotation wrt. static arguments. They fail now because lower really is a type-level variable, making it traceable (which a naked @jit did) will violate the static shape requirement.

@soraros hrm, on what basis are you saying something should or shouldn't fail though? I don't think "type-level variable" is well-defined in Python or JAX. I think we're free to make up our own rules so long as there aren't contradictions.

Are you referring to some problem other than the "grad-of-jit doesn't work, yet jit-of-grad-of-jit does" issue? I agree that's weird. But I think we're only talking about which programs raise errors here; that is, there aren't two valid (non-error) behaviors we're deciding between as far as I know. So why not take the stance that we should just make as few programs raise errors as possible?

And actually the reason inline traced the function first into jaxpr is explained in your very well written PR message here. ;)

Thanks for the link! But I don't grok my explanation there... in particular, running the WrappedFun without an incremented sublevel seems fine now (it must've been core.process_env_traces that I was referring to there, but I don't see the issue). And if all our tests pass (including a gazillion google-internal ones, which I have yet to check), no one can complain, right? (Maybe some other refactor in the interim has made this possible?)

soraros commented Jan 27, 2022

@mattjj I meant all the errors we see in the thread are just more obscure versions of the famous ConcretizationTypeError. A function written as

def f(init, lower):
  def body_fun(_, x):
    return x + 1
  return lax.fori_loop(lower, 1, body_fun, init)

should really be typed as

f :: (..., Concretizable b) => a -> b -> a

The only way for us to mimic this in Python/JAX is to annotate it with partial(jit, static_argnums=(1,)). Seeing such functions as incorrectly typed, I argued that "they shouldn't run". Even given the nicety in #9342, they are still under-typed, and problems can pop up elsewhere. Just imagine replacing all the static_argnums=(...) with inline=True in lax_numpy.py: even if user code runs alright, the API defined is still "wrong".
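
For reference, a sketch of the annotation I mean (whether it composes with the grad-under-jit call in the OP is a separate question, discussed further down the thread):

from functools import partial
import jax
import jax.lax as lax

@partial(jax.jit, static_argnums=(1,))
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

# lower is now a static (Python-level) argument; each distinct value triggers a retrace.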

And here is a bug(?) I found regarding #9342. Consider the following code:

from functools import partial
from jax import jit

# @jit
# @partial(jit, inline=True)
def f(init, lower):
  if lower < 1:
    return 0.
  else:
    return g(init, lower - 1)

# @jit
# @partial(jit, inline=True)
def g(init, lower):
  if lower < 1:
    return 0.
  else:
    return f(init, lower - 1)

# @jit
def h():
  return f(0., 2)

h()
  1. Everything runs if nothing is jitted.
  2. If f and g are @jit-ed and h is not, we see the familiar ConcretizationTypeError as expected. All good here.
  3. If f, g and h are all @jit-ed, we again see the expected error which is ultimately the ConcretizationTypeError. The error message is a bit off (also seen on the main branch).
  4. If f and g are @partial(jit, inline=True)-ed and h is not, ConcretizationTypeError. Comparable with the "grad-of-jit doesn't work, yet jit-of-grad-of-jit does?" issue (not surprising to me).
  5. If f and g are @partial(jit, inline=True)-ed and h is @jit-ed, this should work as in 1. (if I understand #9342, "make jit(inline=True) do trace-time specializing", correctly), yet JAX diverges. <- BUG HERE?
    Notice one sees the same error as 3. on main.
  6. Lastly, if one types them correctly (up to my above-mentioned definition, with @partial(jit, static_argnums=(1,))), everything works.

patrick-kidger (Collaborator, Author) commented:

@soraros I don't think I agree. In terms of type-annotations, the point is that we want to express overloads of a function.

Using lax.fori_loop as an example. From the docs, it has type signature fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a. In this case, Int can either represent something Concretizable, or not. At trace time multiple dispatch (over lower, upper) then dispatches to the appropriate implementation.

(Note that there's no real way to express this in Haskell-like notation, since Haskell has sum types instead of union types.)

As it stands there is no way to wrap lax.fori_loop into a jax.jit, because this kills off the possibility of multiple dispatch; each argument has to be either traced or static'd.

soraros commented Jan 27, 2022

@patrick-kidger

Int can either represent something Concretizable, or not.

I agree, yet I think traced numerical arguments are just less restrictive than static ones. So fori_loop could be typed with a Haskell-like type system just fine:

fori_loop :: (Integral i, Concretizable i, Integral j) =>
  i -> i -> ((j, a) -> a) -> a -> a

assuming

instance Integral Int where ...
instance Integral TracedInt where ...
instance Concretizable Int where
  concretize = id

Maybe my understanding is fundamentally flawed, but could you give an example where lower and upper don't need to be concrete values at trace time (modulo DShapedArray)?

patrick-kidger (Collaborator, Author) commented:

I agree, yet I think traced numerical arguments are just less restrictive than static ones.

I disagree with this. Traced arguments have advantages (no recompilation) and disadvantages (no backpropagation through a loop of unknown length); static arguments have advantages (backprop through iterations; more specialised/efficient compiled code) and disadvantages (recompilation).

So fori_loop could be typed with haskell-like type system just fine:

This doesn't take into account the possibility of traced lower, upper arguments; on which note:

could you give an example where lower and upper don't need to be concrete values at trace time

Consider running an RNN over a sequence whose length isn't known until runtime.

As a more advanced example in the same spirit: consider writing an ODE solver. Sometimes you take fixed steps and the number of steps is known in advance. Sometimes you take adaptive steps and the number of steps is not known until run time. It can be desirable to produce specialised code for each case, whilst handling both cases through a single interface, much like lax.fori_loop.
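
A sketch of the RNN case (a hypothetical step function; assuming lax.fori_loop accepts either a Python int or a traced int32 as the upper bound):

import jax.numpy as jnp
import jax.lax as lax

def run_rnn(weights, xs, length):
    # length may be a Python int (known at trace time) or a traced int32 (only
    # known at run time). Both work with fori_loop, but only the former
    # supports reverse-mode autodiff and more specialised compiled code.
    def step(i, h):
        return jnp.tanh(weights @ h + xs[i])
    return lax.fori_loop(0, length, step, jnp.zeros(weights.shape[0]))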

soraros commented Jan 27, 2022

@patrick-kidger Ahh, my bad, sorry for being slow on your points. I just realized that sometimes lax.fori_loop is really just a lax.while_loop, hence allowing traced lower and upper. Somehow I only think of fori_loop in terms of scan. So I agree that typing fori_loop is not as easy as setting static arguments, and that:

As it stands there is no way to wrap lax.fori_loop into a jax.jit, because this kills off the possibility of multiple dispatch; each argument has to be either traced or static.

mattjj commented Jan 28, 2022

@soraros whoa, any idea why your example (5) diverges? I ran it and I couldn't immediately tell what was going on. (Thanks for finding the error message bug too!)

I should've mentioned earlier: static_argnums doesn't always work, e.g. in the OP example, because static arguments must be hashable, and Tracers (like those involved in the grad call in the OP) aren't hashable. I'm not sure if there's an easy fix in that direction.
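
To illustrate the hashability requirement (a sketch; the exact error text varies by version):

from functools import partial
import numpy as np
import jax

@partial(jax.jit, static_argnums=(1,))
def scale(x, factor):
    return x * factor

scale(1.0, 3)               # fine: 3 is hashable
# scale(1.0, np.arange(3))  # errors: non-hashable static argument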

Another thing I should've mentioned earlier: fori_loop sometimes lowers to a while_loop and sometimes a scan. That is part of what's making the semantics discussion here confusing, I think.
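
For reference, a simplified sketch of that dispatch (not JAX's actual implementation): pick scan when the trip count is concrete at trace time, and while_loop otherwise.

import jax.numpy as jnp
import jax.lax as lax
from jax.core import Tracer

def my_fori_loop(lower, upper, body_fun, init_val):
    # Simplified sketch, not the real lax.fori_loop.
    if isinstance(lower, Tracer) or isinstance(upper, Tracer):
        # Trip count only known at run time: while_loop (no reverse-mode autodiff).
        def cond(carry):
            i, _ = carry
            return i < upper
        def body(carry):
            i, x = carry
            return i + 1, body_fun(i, x)
        _, out = lax.while_loop(cond, body, (lower, init_val))
        return out
    # Trip count known at trace time: scan (reverse-mode autodiff works).
    def scanned(x, i):
        return body_fun(i, x), None
    out, _ = lax.scan(scanned, init_val, jnp.arange(lower, upper))
    return out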

soraros commented Jan 28, 2022

@mattjj If you disable FLAGS.experimental_cpp_jit it works, so I guess that's where the problem lies?
And here is a slightly smaller repro:

from functools import partial

import jax
from jax import jit

jax.config.update('jax_platform_name', 'cpu')

@partial(jit, inline=True)
def f(rec):
  if rec:
    return f(False)
  else:
    return 0

@jit
def h():
  return f(True)

h()

Somehow the inner call, namely f(False), never returns.

mattjj commented Mar 16, 2022

Thanks for sharing that finding. There's a recently-uncovered bug in the C++ jit dispatch path which might explain the deadlock.

I've lost the plot on this issue thread. I could page it back in but I'm hoping that you folks can help me avoid that...

@patrick-kidger and @soraros is the conclusion here that we should just close the issue? Or land #9342? Or figure out something else to do?

patrick-kidger (Collaborator, Author) commented:

I think land #9342 for certain.

Once that's in, I don't think there's ever a reason to use jax.jit without inline=True, actually. It slightly improves tracing with no downside. (Unless you really care about having a jaxpr that looks slightly different for some reason.)

JoeyTeng added a commit to JoeyTeng/jaxrenderer that referenced this issue on Jun 5, 2023
This should be more beneficial, following the discussions in the JAX repository; see
jax-ml/jax#6584 jax-ml/jax#6681 jax-ml/jax#9298 jax-ml/jax#9342
JoeyTeng added a commit to JoeyTeng/jaxrenderer that referenced this issue Jun 12, 2023
* refactor(jit): make all jitted functions "inline=True"

This should be more beneficial, following the discussions in the JAX repository; see
jax-ml/jax#6584 jax-ml/jax#6681 jax-ml/jax#9298 jax-ml/jax#9342

* perf(pipeline): try to render 4/2/1 rows per batch using vmap to reduce fori_loop iterations

* feat(_meta_utils): simple way to add multiple trace annotations together for functions

add `@ad_tracing_name` to most functions to assist profiling
also bump to Python 3.10

BREAKING CHANGE: Now requires Python 3.10

* perf(pipeline): big refactor to not update per row, but render all rows then concat and merge

* perf(pipeline): using scan + unroll (equiv map + unroll)

This is very similar to map + vmap (minibatch processing) as the inner
loop is too complex

* build(pyproject): try to relax jax{lib}'s version constraint to ">=0.3.25,<5.0.0"

* test(pre-gen-brax): example inputs for profiling

* perf: try to eliminate all `lax.cond`

under `vmap`, `lax.cond` is lowered to `select_n` in HLO, which executes both branches and thus
fails to 1) save computation when possible and 2) prevent unexpected values from being produced or
unexpected branches from being executed (defensively). So let the non-dummy branch execute anyway
and only rule out garbage values at the final stage, all together, to try to improve performance.
See google/brax#8409 for more details about unconditional execution of cond under vmap

* fix(pipeline): gl_FrontFacing: fix its determination in pipeline

`True` if NOT back-facing

* perf: added extra stage in pipeline, aiming to interpolate and shade only one fragment per pixel

* docs(changelog): expose option `loop_unroll`; dependency version change

Bump minimum Python version from 3.9 to 3.10;
lower minimum jax & jaxlib to 0.3.25.

* build(pyproject): bump to 0.3.0