This issue was moved to a discussion.
Target JAX's intermediate representation with the JAX linker #1014
Here are a couple of issues for which the idea was considered as a possible solution:
Can we actually circumvent these limitations this way? In addition, would Aesara still be able to generate JAX code that we can pass around to other libraries (e.g., BlackJAX, which will call grad/JIT on a user-defined JAX function)?
This is a legitimate concern, and something we should figure out before investing too much time in it.
To follow up on the previous discussion, we were considering the following function:

```python
from jax import lax
from jax import make_jaxpr
import jax.numpy as jnp

def add_fn(a, b):
    return lax.add(a, b)

x = jnp.array([1., 1.])
y = jnp.array([2., 3.])
add_fn(x, y)
# [3., 4.]
```

JAX traces the user's functions to translate them into (closed) Jaxprs, and those contain information about the shape and type of the inputs:

```python
from jax import make_jaxpr

add_jaxpr = make_jaxpr(add_fn)(x, y)
add_jaxpr
# { lambda ; a:f32[2] b:f32[2]. let c:f32[2] = add a b in (c,) }

add_1d_jaxpr = make_jaxpr(add_fn)(1., 1.)
add_1d_jaxpr
# { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) }
```

The Jaxprs are objects whose components can be inspected:

```python
add_1d_jaxpr.jaxpr.eqns
# [a:f32[] = add b c]
add_1d_jaxpr.jaxpr.invars
# [a, b]
add_1d_jaxpr.jaxpr.outvars
# [c]
```

More interestingly, we can get an object that behaves like a function from the closed Jaxpr:

```python
from jax.core import jaxpr_as_fun

add_1d = jaxpr_as_fun(add_1d_jaxpr)
add_1d
# functools.partial(<function jaxpr_as_fun at 0x7f115d9ece50>, { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) })
```

This needs to be double-checked, but it seems that no tracing is happening anymore; I can for instance pass the 2-element arrays `x` and `y`, even though `add_1d_jaxpr` was traced with scalars:

```python
add_1d(x, y)
# [DeviceArray([3., 4.], dtype=float32)]
```

As explained in the internals documentation, the interpreter itself is traceable, so we can JIT-compile this function:

```python
import jax

jitted_add1 = jax.jit(add_1d)
make_jaxpr(jitted_add1)(1., 1.)
# { lambda ; a:f32[] b:f32[]. let
#     c:f32[] = xla_call[
#       call_jaxpr={ lambda ; d:f32[] e:f32[]. let f:f32[] = add d e in (f,) }
#       name=<unnamed wrapped function>
#     ] a b
#   in (c,) }
jitted_add1(1., 2.)
# [DeviceArray(3., dtype=float32, weak_type=True)]
```

It feels safe to target

Unrelated note: we should be able to determine the largest jit-able (sub)set of the code by doing static analysis of the corresponding

It may still be possible to

As far as I understand the motivation behind the

This file is a good starting point for the translations from Ops to XLA. I see mentions of MLIR in this file; if XLA can interpret MLIR, we may want to target MLIR directly. There is a roadmap, but it is hard to know whether and when this is going to be done; if JAX starts lowering to MLIR, there's a good chance this will happen?
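The remark that the Jaxpr interpreter is traceable can be made concrete with a small interpreter of our own. The sketch below is modeled on the "Writing custom Jaxpr interpreters" tutorial in the JAX documentation, not on any Aesara code; `eval_jaxpr_sketch` is a hypothetical name, and the location of `Literal` is an assumption that varies across JAX releases (hence the import fallback). Because the loop only calls `primitive.bind` on whatever values it receives, it runs the same way on concrete arrays and on the tracers that `jax.jit` passes in.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # recent releases expose Jaxpr internals under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older releases
    from jax.core import Literal


def eval_jaxpr_sketch(jaxpr, consts, *args):
    """Evaluate a Jaxpr by binding each equation's primitive in order."""
    env = {}

    def read(var):
        # Literals carry their value inline; variables are looked up in env.
        return var.val if isinstance(var, Literal) else env[var]

    def write(var, val):
        env[var] = val

    for var, val in zip(jaxpr.constvars, consts):
        write(var, val)
    for var, val in zip(jaxpr.invars, args):
        write(var, val)

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        # bind() works on concrete arrays and on tracers alike.
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            write(var, val)

    return [read(v) for v in jaxpr.outvars]


def add_fn(a, b):
    return lax.add(a, b)


closed = jax.make_jaxpr(add_fn)(1., 1.)
one, two = jnp.float32(1.), jnp.float32(2.)

# Evaluate eagerly, then under jit: the interpreter itself gets traced.
eager = eval_jaxpr_sketch(closed.jaxpr, closed.consts, one, two)
jitted = jax.jit(lambda a, b: eval_jaxpr_sketch(closed.jaxpr, closed.consts, a, b))
compiled = jitted(one, two)
```

Both calls return a one-element list, mirroring the list that `jaxpr_as_fun` returns.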
Cool. I assume other transformations like `grad` and `vmap` can also be performed in the same way as `jit`, after calling
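A minimal sketch checking that assumption, using a hypothetical `mul_fn` so the gradient is not trivially 1. It assumes `jaxpr_as_fun` is importable from `jax.extend.core` in recent releases and from `jax.core` in older ones. Since the function returned by `jaxpr_as_fun` yields a list of outputs, the wrapper takes the first element so that `jax.grad` sees a scalar.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # recent releases
    from jax.extend.core import jaxpr_as_fun
except ImportError:  # older releases
    from jax.core import jaxpr_as_fun


def mul_fn(a, b):
    return lax.mul(a, b)


closed = jax.make_jaxpr(mul_fn)(1., 1.)
mul_from_jaxpr = jaxpr_as_fun(closed)  # returns a list of outputs


def mul_scalar(a, b):
    # Unwrap the single output so jax.grad sees a scalar.
    return mul_from_jaxpr(a, b)[0]


d_da = jax.grad(mul_scalar)(3., 4.)  # d(a * b)/da = b, i.e. 4.0
batched = jax.vmap(mul_scalar)(jnp.array([1., 2.]), jnp.array([3., 4.]))  # [3., 8.]
```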
However,
It is clear now that by targeting JAX's IR directly we would still be able to use

What we do from here depends on the goals we set for the transpilation. If the goal is compatibility with the JAX ecosystem, then the approach that the dispatcher currently takes is the most appropriate. If we want to target XLA while avoiding JAX's self-imposed limitations (i.e., build a JAX replacement of sorts), then we might as well go all the way, target XLA's IR directly and use

I believe that in the short term we should aim for compatibility with the broader JAX ecosystem. It is fairly simple, it allows Aesara to piggyback on a much broader ecosystem, and we all know the size of the ecosystem is critical when it comes to adoption. We can, however, address some of the issues that motivated this thread on the Aesara side: for instance, by making sure that shapes that are known to be constant at compile time are indeed set to a constant value before compiling. When it comes to known limitations of JAX such as dynamic shapes, we can fail gracefully and explain that this is due to a limitation on JAX's side. For things that JAX traces out, like assert statements, I would simply warn the user that they have been removed because of a limitation on JAX's side. Users still get the many benefits of Aesara, like its rewrite system, while being able to use their favorite library (hopefully they will eventually see the interest in porting said library to Aesara).

Nevertheless, XLA remains an interesting target in itself for GPU and TPU. I think it is worth diving into the XLA documentation directly to figure out what we may gain from bypassing JAX altogether.
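One way to picture both suggestions (fixing shapes that are known at compile time, and failing gracefully on dynamic shapes) is the sketch below. `JAXLimitationError` and `jit_with_explanation` are hypothetical names, not part of Aesara or JAX. The only JAX pieces relied upon are `jax.jit` with `static_argnums`, the `size` argument of `jnp.nonzero`, and `jax.errors.ConcretizationTypeError`, which JAX raises when a shape would depend on traced values.

```python
import jax
import jax.numpy as jnp


class JAXLimitationError(Exception):
    """Hypothetical error type for graphs that hit a known JAX limitation."""


def jit_with_explanation(fn, **jit_kwargs):
    """Hypothetical wrapper: jit `fn`, re-raising concretization errors with context."""
    jitted = jax.jit(fn, **jit_kwargs)

    def wrapper(*args):
        try:
            return jitted(*args)
        except jax.errors.ConcretizationTypeError as err:
            raise JAXLimitationError(
                "This graph requires a value-dependent (dynamic) shape, which "
                "JAX cannot compile; this is a limitation on JAX's side."
            ) from err

    return wrapper


def nonzero_dynamic(x):
    # The output length depends on the values in x, so JAX cannot jit it.
    return jnp.nonzero(x)[0]


def nonzero_static(x, size):
    # The output length is a compile-time constant, so JAX can jit it.
    return jnp.nonzero(x, size=size)[0]


x = jnp.array([0., 1., 0., 2.])
indices = jit_with_explanation(nonzero_static, static_argnums=1)(x, 2)  # [1, 3]

try:
    jit_with_explanation(nonzero_dynamic)(x)
except JAXLimitationError as err:
    print(err)
```

In Aesara's case, the static `size` would come from a shape the rewrite system has already proven constant, rather than from the caller.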
The JAX linker currently targets the library's NumPy-like high-level API. For instance, the `Dot` Op is translated using `jax.numpy.dot`:

However, JAX is a symbolic library (albeit a limited one) and has its own intermediate representation. When the user calls a function written with `jax.numpy` primitives for the first time, JAX traces the function and converts it to a Jaxpr that is then processed by XLA. Therefore, when one transpiles their Aesara code to JAX and runs it, the resulting code is traced. This is completely unnecessary, since all the information needed to build JAX's intermediate representation is already contained in the Aesara graph.

We could therefore, in theory, translate Aesara's Ops directly to JAX's intermediate representation. Not only would we improve runtime performance (gain to be estimated), we would also have more freedom in the transpilation, since we would not be limited to JAX's high-level API.
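For reference, a minimal sketch of the high-level approach described above. `Dot` and `jax_funcify` here are hypothetical stand-ins built on a `functools.singledispatch` registry, not imports from Aesara. Printing the Jaxpr of the translated function shows that JAX re-derives its own representation (a `dot_general` primitive) by tracing, which is the work the Aesara graph could have spared it.

```python
from functools import singledispatch

import jax
import jax.numpy as jnp


class Dot:
    """Hypothetical stand-in for Aesara's `Dot` Op."""


@singledispatch
def jax_funcify(op, **kwargs):
    """Hypothetical dispatcher: map an Op instance to a JAX callable."""
    raise NotImplementedError(f"No JAX translation for {type(op).__name__}")


@jax_funcify.register(Dot)
def _jax_funcify_dot(op, **kwargs):
    # The translation targets JAX's high-level, NumPy-like API.
    def dot(x, y):
        return jnp.dot(x, y)

    return dot


dot_fn = jax_funcify(Dot())
a = jnp.ones((2, 3))
b = jnp.ones((3, 4))

result = dot_fn(a, b)  # a (2, 4) array filled with 3.0
# Tracing the translated function reveals the Jaxpr JAX builds on its own.
print(jax.make_jaxpr(dot_fn)(a, b))
```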
Proof of concept
Before opening a PR, I will try in the comments of this issue to translate the following Aesara graph:
To its JAX equivalent:
This example is simple, but it raises the question of how types and shapes are handled in JAX's IR. In particular, I am currently not sure that JAX can handle arrays of unknown (but fixed) length. If it cannot, we can imagine a "delayed transpilation" where Aesara would generate JAX's IR when the function is called with arguments.