
Target JAX's intermediate representation with the JAX linker #1014

Closed
rlouf opened this issue Jun 24, 2022 · 7 comments

Labels: enhancement (New feature or request), important, JAX (Involves JAX transpilation), refactor (This issue involves refactoring)

rlouf (Member) commented Jun 24, 2022

The JAX linker currently targets the library's NumPy-like high-level API. For instance, the Dot Op is translated using jax.numpy.dot:

import jax.numpy as jnp

from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.math import Dot


@jax_funcify.register(Dot)
def jax_funcify_Dot(op, **kwargs):
    def dot(x, y):
        return jnp.dot(x, y)

    return dot

However, JAX is a symbolic library (albeit a limited one) and has its own intermediate representation. When the user calls a function written with jax.numpy primitives for the first time, JAX traces the function and converts it to a jaxpr that is then processed by XLA. Therefore, when one transpiles their Aesara code to JAX and runs it, the resulting code is traced. This is completely unnecessary, since all the information needed to build JAX's intermediate representation is already contained in the Aesara graph.
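
To make the tracing step concrete, here is a minimal demonstration: Python side effects run only while JAX traces a function, so the print below fires on the first call and is absent from the cached, compiled version:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing!")  # executes during tracing only
    return jnp.dot(x, x)

x = jnp.ones(3)
f(x)  # first call: traced, prints "tracing!"
f(x)  # second call: cached XLA computation, prints nothing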

We could therefore, in theory, translate Aesara's Ops directly to JAX's intermediate representation. We would not only improve runtime performance (the gain remains to be estimated), but also gain more freedom in the transpilation, since we would no longer be limited to JAX's high-level API.

Proof of concept

Before opening a PR, I will try, in the comments of this issue, to translate the following Aesara graph:

import aesara
import aesara.tensor as at

a = at.vector()
b = at.vector()

c = a + b

aesara.dprint(c)
# Elemwise{add,no_inplace} [id A]
#  |<TensorType(float64, (None,))> [id B]
#  |<TensorType(float64, (None,))> [id C]

To its JAX equivalent:

import jax.numpy as jnp
from jax import lax
from jax import make_jaxpr

def add_fn(a, b):
    return lax.add(a, b)

print(make_jaxpr(add_fn)(jnp.array([1., 1.]), jnp.array([1., 1.])))
# { lambda ; a:f32[2] b:f32[2]. let c:f32[2] = add a b in (c,) }

This example is simple, but it raises the question of how types and shapes are handled in JAX's IR. In particular, I am currently not sure that JAX can handle arrays of unknown (but fixed) length. If it cannot, we can imagine a "delayed transpilation" where Aesara would generate JAX's IR only when the function is called with concrete arguments.
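
One relevant detail: JAX's tracing machinery never needs concrete data, only shapes and dtypes, as jax.eval_shape demonstrates with jax.ShapeDtypeStruct placeholders. The shape must still be a concrete tuple, though, which is exactly the limitation in question:

import jax
import jax.numpy as jnp
from jax import lax

def add_fn(a, b):
    return lax.add(a, b)

# a placeholder carrying only shape and dtype, no values
abstract = jax.ShapeDtypeStruct((2,), jnp.float32)
print(jax.eval_shape(add_fn, abstract, abstract))
# ShapeDtypeStruct(shape=(2,), dtype=float32)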

@rlouf rlouf self-assigned this Jun 24, 2022
@rlouf rlouf added the JAX Involves JAX transpilation label Jun 24, 2022
@brandonwillard brandonwillard added enhancement New feature or request important refactor This issue involves refactoring labels Jun 24, 2022
brandonwillard (Member) commented:

Here are a couple of issues for which the idea was considered as a possible solution:

ricardoV94 (Contributor) commented Jun 26, 2022

> Here are a couple of issues for which the idea was considered as a possible solution:

Can we actually circumvent these limitations this way?

In addition, would Aesara still be able to generate JAX code that we can pass around to other libraries (e.g., BlackJax, which will call grad/JIT on a user-defined JAX function)?

rlouf (Member, Author) commented Jun 26, 2022

> In addition, would Aesara still be able to generate JAX code that we can pass around to other libraries (e.g., BlackJax, which will call grad/JIT on a user-defined JAX function)?

This is a legitimate concern, and something we should figure out before investing too much time in it.

rlouf (Member, Author) commented Jul 4, 2022

To follow up on the previous discussion: we were considering the following function:

from jax import lax
from jax import make_jaxpr
import jax.numpy as jnp

def add_fn(a, b):
    return lax.add(a, b)

x = jnp.array([1., 1.])
y = jnp.array([2., 3.])
add_fn(x, y)
# [3., 4.]

JAX traces the user's functions to translate them into (closed) jaxprs, which contain information about the shape and type of the inputs:

from jax import make_jaxpr

add_jaxpr = make_jaxpr(add_fn)(x, y)
add_jaxpr
# { lambda ; a:f32[2] b:f32[2]. let c:f32[2] = add a b in (c,) }

add_1d_jaxpr = make_jaxpr(add_fn)(1., 1.)
add_1d_jaxpr
# { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) }

The jaxprs are objects:

add_1d_jaxpr.jaxpr.eqns
# [a:f32[] = add b c]
add_1d_jaxpr.jaxpr.invars
# [a, b]
add_1d_jaxpr.jaxpr.outvars
# [c]
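
Each equation also carries the primitive it applies, which is roughly the level an Aesara Op would map onto. A quick way to inspect this:

for eqn in add_1d_jaxpr.jaxpr.eqns:
    # a JaxprEqn records the primitive plus its typed input/output variables
    print(eqn.primitive, eqn.invars, eqn.outvars)
# add [a, b] [c]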

More interestingly, we can get an object that behaves like a function from a ClosedJaxpr using what the JAX developers call an interpreter:

from jax.core import jaxpr_as_fun

add_1d = jaxpr_as_fun(add_1d_jaxpr)
add_1d
# functools.partial(<function jaxpr_as_fun at 0x7f115d9ece50>, { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) })

This needs to be double-checked, but it seems that no tracing is happening anymore; I can, for instance, pass the x and y arrays to the function built from the jaxpr obtained by tracing with scalars:

add_1d(x, y)
# [DeviceArray([3., 4.], dtype=float32)]
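
Indeed, jaxpr_as_fun is little more than a partial application of the jaxpr interpreter, jax.core.eval_jaxpr, which we can also call directly:

from jax.core import eval_jaxpr

# interpret the equations one by one, binding each primitive directly;
# no Python function is traced here
eval_jaxpr(add_1d_jaxpr.jaxpr, add_1d_jaxpr.consts, x, y)
# [DeviceArray([3., 4.], dtype=float32)]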

As explained in the internals documentation, the interpreter itself is traceable, so we can JIT-compile this function:

import jax

jitted_add1 = jax.jit(add_1d)
make_jaxpr(jitted_add1)(1., 1.)
# { lambda ; a:f32[] b:f32[]. let
#     c:f32[] = xla_call[
#       call_jaxpr={ lambda ; d:f32[] e:f32[]. let f:f32[] = add d e in (f,) }
#       name=<unnamed wrapped function>
#     ] a b
#   in (c,) }
jitted_add1(1., 2.)
# [DeviceArray(3., dtype=float32, weak_type=True)]

It feels safe to target jaxprs for now. The next step is to build the function add_fn by constructing the ClosedJaxpr manually (i.e., not by tracing a Python function). Then we will try to understand what happens when jax.jit traces evaluated jaxprs.
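
For reference, here is a rough sketch of what that manual construction could look like using jax.core directly. Note that jax.core is an internal, version-sensitive API (for instance, the effects argument of new_jaxpr_eqn is a recent addition), so this is an illustration rather than a recipe:

import numpy as np
from jax import core, lax
from jax.core import jaxpr_as_fun

# build { lambda ; a b. let c = add a b in (c,) } without tracing anything
aval = core.ShapedArray((), np.dtype("float32"))
newvar = core.gensym()
a, b, c = newvar(aval), newvar(aval), newvar(aval)

eqn = core.new_jaxpr_eqn([a, b], [c], lax.add_p, {}, set())
jaxpr = core.Jaxpr(constvars=[], invars=[a, b], outvars=[c], eqns=[eqn])
closed = core.ClosedJaxpr(jaxpr, [])
print(closed)
# { lambda ; a:f32[] b:f32[]. let c:f32[] = add a b in (c,) }

print(jaxpr_as_fun(closed)(1.0, 2.0))
# [DeviceArray(3., dtype=float32)]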

Unrelated note

We should be able to determine the largest jit-able (sub)set of the code by statically analyzing the corresponding Aesara graph: jit obeys very simple rules, and those can be checked at compile time. This could be an appreciated feature, and it would potentially allow us to transpile code that contains tensors of varying shapes.
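
A hypothetical sketch of such a check; the helper below and its reliance on static shape information attached to TensorType are assumptions, not existing Aesara API:

from aesara.graph.basic import io_toposort

def jitable_nodes(inputs, outputs):
    # hypothetical: a node is a jit candidate only if every output has a
    # fully static shape, a rough proxy for jax.jit's static-shape rule
    good = []
    for node in io_toposort(inputs, outputs):
        shapes = [getattr(v.type, "shape", None) for v in node.outputs]
        if all(s is not None and None not in s for s in shapes):
            good.append(node)
    return good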

It may still be possible to fully jit functions that use e.g. jax.numpy.reshape, but we may need to implement our own jitting function (aesara.link.jax.jit). We need to explore XLA's primitives to see what the true limitations are (as opposed to those baked into JAX). We could use JAX merely as Python bindings to XLA and lower the jaxprs we create to functions.

As far as I understand the motivation behind the omnistaging change in JAX (jax-ml/jax#3370), the issues it tries to solve can be circumvented when one has a symbolic graph one can analyze.

This file is a good starting point for the translations from Ops to XLA. I see mentions of MLIR in this file; if XLA can interpret MLIR, we may want to target MLIR directly. There is a roadmap, but it is hard to know whether and when this will happen; if JAX starts lowering to MLIR, there is a good chance it will.

ricardoV94 (Contributor) commented Jul 9, 2022

Cool. I assume other transformations like grad and vmap can also be performed the same way as jit, after calling jaxpr_as_fun?

rlouf (Member, Author) commented Jul 9, 2022

Yes.

jax.grad requires tracing the function to build a "new graph", so it will not be possible to pass a function built this way as an argument. This is a minor inconvenience, as Aesara can compute the gradients itself.

However, jax.jit, jax.vmap, and jax.pmap (and the loops) would work with these functions; see the quick check below.
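
For instance (a quick sanity check, reusing add_1d_jaxpr from the July 4 comment), vmap composes with the interpreter-backed function:

import jax
import jax.numpy as jnp
from jax.core import jaxpr_as_fun

add_1d = jaxpr_as_fun(add_1d_jaxpr)

# vmap re-traces the interpreter itself, which is traceable
jax.vmap(add_1d)(jnp.arange(3.0), jnp.arange(3.0))
# [DeviceArray([0., 2., 4.], dtype=float32)]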

rlouf (Member, Author) commented Sep 13, 2022

It is now clear that by targeting JAX's IR directly we would still be able to use jax.jit and jax.vmap on the compiled function, but we would not be able to apply transformations like jax.grad. There is no free lunch.

What we do from here depends on the goals we set for the transpilation: if the goal is compatibility with the JAX ecosystem, then the approach the dispatcher currently takes is the most appropriate. If we want to target XLA while avoiding JAX's self-imposed limitations (i.e., build a JAX replacement of sorts), then we might as well go all the way, target XLA's IR directly, and use jaxlib as a bridge.

I believe that, short term, we should aim for compatibility with the broader JAX ecosystem. It is fairly simple, it allows Aesara to piggyback on a much broader ecosystem, and we all know the size of an ecosystem is critical when it comes to adoption. We can nevertheless address some of the issues that motivated this thread by working on the Aesara side: for instance, by making sure that shapes that are known to be constant at compile time are indeed set to constant values before compiling (see the sketch below). When it comes to known limitations of JAX, like dynamic shapes, we can fail gracefully and explain that this is due to a limitation on JAX's side. For things that JAX traces out, like assert statements, I would simply warn the user that they have been removed because of a limitation on JAX's side. Users still get the many benefits of Aesara, like its rewrite system, while being able to use their favorite library (hopefully they will eventually see the value in porting said library to Aesara).
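
As an illustration (the rewrite pass itself is hypothetical), a shape that is known at compile time can already be pinned in the graph with specify_shape, so the transpiled function sees a static length instead of an unknown one:

import aesara.tensor as at
from aesara.tensor.shape import specify_shape

a = at.vector()
a3 = specify_shape(a, (3,))  # the graph now carries the static length

c = a3 + a3  # can be transpiled as f64[3] rather than a dynamic shape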

Nevertheless, XLA remains an interesting target in itself for GPUs and TPUs. I think it is worth diving into the XLA documentation directly to figure out what we may gain from bypassing JAX altogether.

@aesara-devs aesara-devs locked and limited conversation to collaborators Sep 14, 2022
@rlouf rlouf converted this issue into discussion #1184 Sep 14, 2022

