
jax.vmap fails for custom PyTree #16170

Closed
francois-rozet opened this issue May 29, 2023 · 14 comments
Labels
bug Something isn't working


@francois-rozet

francois-rozet commented May 29, 2023

Description

I implemented a custom PyTree class that automatically separates array and non-array leaves. jax.vmap fails when the input/output is an instance of that class.

import jax

@jax.tree_util.register_pytree_node_class
class Custom:
    def __init__(self, key, switch=True):
        if switch:
            self.x = jax.random.normal(key)  # not-static
        else:
            self.x = None  # static

    def tree_flatten(self):
        if isinstance(self.x, jax.Array):
            return [self.x], [None]
        else:
            return [None], [self.x]

    @classmethod
    def tree_unflatten(cls, static, children):
        y, x = static[0], children[0]
        x = x if y is None else y

        self = object.__new__(cls)
        self.x = x

        return self

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)

jax.vmap(Custom)(keys)

fails with

Traceback (most recent call last):
  File ".../demo.py", line 42, in <module>
    print(jax.vmap(Custom)(jax.random.split(key, 2)))
  File ".../miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File ".../miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/api.py", line 1239, in vmap_f
    out_flat = batching.batch(
  File ".../miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/linear_util.py", line 203, in call_wrapped
    ans = gen.send(ans)
  File ".../miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/interpreters/batching.py", line 567, in _batch_inner
    out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  File ".../miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/api.py", line 1241, in <lambda>
    lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
  File ".../miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/api_util.py", line 424, in flatten_axes
    assert len(axes) == treedef.num_leaves
jax._src.traceback_util.UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../demo.py", line 42, in <module>
    print(jax.vmap(Custom)(jax.random.split(key, 2)))
AssertionError

I believe the error comes from the way jax.vmap assigns vectorized axes to the tree leaves.

https://github.com/google/jax/blob/ae9160a4e9f14992c9f53a38a1aeaf146eacbf16/jax/_src/api_util.py#L395-L400

Here, it is assumed that unflattening the treedef with object() placeholders instead of the original leaves, and then flattening the resulting tree, yields the same (number of) leaves. This is not the case with my custom PyTree. I see two ways of fixing this:

  1. Use a placeholder other than object() that custom PyTrees can recognize, to ensure that the leaves stay the same.
  2. Refactor jax.vmap so that it does not rely on placeholder leaves.
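A minimal standalone sketch of the mismatch described above, assuming only the public jax.tree_util API: it reproduces the placeholder round trip (unflatten with object() leaves, then flatten again) outside of vmap. The class is a trimmed-down copy of the Custom class from the report that takes the value directly instead of a PRNG key; this is an illustration, not JAX's internal code.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Custom:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # same type-dependent logic as in the report
        if isinstance(self.x, jax.Array):
            return [self.x], [None]
        return [None], [self.x]

    @classmethod
    def tree_unflatten(cls, static, children):
        y, x = static[0], children[0]
        obj = object.__new__(cls)
        obj.x = x if y is None else y
        return obj


leaves, treedef = jtu.tree_flatten(Custom(jnp.ones(3)))

# Swap every leaf for a bare object(), then flatten again.
placeholder_tree = jtu.tree_unflatten(treedef, [object()] * treedef.num_leaves)
reflattened = jtu.tree_leaves(placeholder_tree)

print(len(leaves), len(reflattened))  # 1 0
```

The object() placeholder is not a jax.Array, so the second flatten moves it into the static data and the leaf count drops from 1 to 0, which is exactly the condition the `assert len(axes) == treedef.num_leaves` in the traceback trips on.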

What jax/jaxlib version are you using?

jax 0.4.10, jaxlib 0.4.10

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.9.16, Ubuntu 22.04

NVIDIA GPU info

No response

@francois-rozet
Author

francois-rozet commented May 31, 2023

Hello @mattjj and @hawkinsp (tagging you because you are mentioned in the code). I took a look at how vmap works.

If I understand correctly, the core (batching.batch) only works with flattened inputs and outputs, which requires flattened in_axes and out_axes. But since the output tree is not known until the function has been called, out_axes cannot be flattened upfront. To work around this, the implementation uses intertwined generators, which is quite convoluted.

I came up with a simpler implementation, which does not rely on generators and does not require replacing the leaves of the trees to flatten the axes.

import jax
import jax.tree_util as jtu


def flatten_axes(tree, axes):  # /!\ uses the tree instead of the treedef
    flat_axes = []

    def add_axes(axis, x):
        # one copy of `axis` for every leaf actually present in the subtree x
        flat_axes.extend([axis] * len(jtu.tree_leaves(x)))

    # is_leaf keeps None as an axis spec instead of treating it as an empty subtree
    jtu.tree_map(add_axes, axes, tree, is_leaf=lambda a: a is None)

    return flat_axes


def vmap(fun, in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None):
    def wrapped(*ins):
        flat_in_axes = flatten_axes(ins, in_axes)
        flat_ins, treedef_ins = jtu.tree_flatten(ins)

        # local copy: assigning to axis_size itself would make it local to
        # wrapped and raise UnboundLocalError on the `is None` check
        size = axis_size
        if size is None:
            size = mystery_logic(flat_ins, flat_in_axes)  # placeholder for _mapped_axis_size

        flat_out_axes = []  # filled when flat_fun is called
        store = []  # receives the output treedef

        def flat_fun(flat_ins):
            ins = jtu.tree_unflatten(treedef_ins, flat_ins)
            outs = fun(*ins)

            flat_out_axes.extend(flatten_axes(outs, out_axes))
            flat_outs, treedef_outs = jtu.tree_flatten(outs)

            store.append(treedef_outs)

            return flat_outs

        flat_outs = flat_vmap(  # placeholder for batching.batch
            flat_fun,
            flat_in_axes,
            flat_out_axes,  # only read after flat_fun has been called
            size,
            axis_name,
            spmd_axis_name,
        )(flat_ins)

        treedef_outs = store.pop()

        return jtu.tree_unflatten(treedef_outs, flat_outs)

    return wrapped

I don't really understand all the technicalities of _mapped_axis_size and batching.batch, so I used mystery_logic and flat_vmap as placeholders for them. This implementation should be faster, as it flattens/unflattens fewer trees and does not use generators. Tell me what you think!
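The tree-based flatten_axes from the sketch can be exercised on its own. Below is a runnable copy; one assumption on my part is the is_leaf=lambda a: a is None argument to tree_map, since jax.tree_util otherwise treats a None axis spec as an empty subtree rather than as "do not map this argument". The input tuple is made up for the example.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def flatten_axes(tree, axes):
    """Broadcast an axis-spec prefix tree over the actual leaves of `tree`."""
    flat_axes = []

    def add_axes(axis, x):
        # one copy of `axis` per leaf actually present in the subtree x
        flat_axes.extend([axis] * len(jtu.tree_leaves(x)))

    # is_leaf keeps None as an axis spec instead of an empty subtree
    jtu.tree_map(add_axes, axes, tree, is_leaf=lambda a: a is None)
    return flat_axes


ins = (jnp.ones((4, 3)), {"a": jnp.zeros(4), "b": None}, jnp.ones(3))
flat_in_axes = flatten_axes(ins, (0, 0, None))
print(flat_in_axes)               # [0, 0, None]
print(len(jtu.tree_leaves(ins)))  # 3
```

Because the axes are counted against the leaves of the real tree, their number always matches jtu.tree_flatten(ins), whatever the custom flattening logic does; no placeholder tree is ever built.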

@jakevdp
Collaborator

jakevdp commented Jun 1, 2023

The issue here is that your pytree flattening is conditioned on whether the attributes are array instances – this breaks vmap and other JAX transforms, which implicitly assume that flattening behavior does not change depending on the type of the attributes. Instead, I would implement your flattening without that logic – something like this:

    def tree_flatten(self):
        return [self.x], [None]
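For reference, a minimal sketch of the issue's Custom class rewritten along those lines: x is always exposed as the single child, regardless of its type. The switch=False branch from the report is dropped here, which is an assumption about what the use case can give up.

```python
import jax
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Custom:
    def __init__(self, key):
        self.x = jax.random.normal(key)

    def tree_flatten(self):
        # The structure is the same whatever the type of self.x.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = object.__new__(cls)
        obj.x = children[0]
        return obj


keys = jax.random.split(jax.random.PRNGKey(0), 4)
batched = jax.vmap(Custom)(keys)
print(batched.x.shape)  # (4,)
```

Since the object() placeholder still lands in the children, vmap's out_axes flattening sees one leaf before and after the round trip.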

@francois-rozet
Author

francois-rozet commented Jun 1, 2023

Hello @jakevdp, thank you for your comment, but I actually need that logic to make smart pytrees that are easy to use with jit, grad and vmap. It is already working fine with all of them, except when the smart pytree is one of the inputs/outputs of vmap.

I know that vmap assumes that the tree structure does not change with respect to the leaves (I already mentioned it in the description of the issue), but I think this is a bug and not a feature. If JAX transformations cannot handle smart pytrees (which otherwise fully respect the pytree API), then what is the point of register_pytree anyway? This limitation is one of the reasons why it is very hard to make a simple (i.e. without many layers of abstraction) NN library based on JAX.

The new implementation of jax.vmap I propose is simpler, (likely) faster and allows for arbitrary custom pytrees as it does not assume the tree structure is independent from the leaf types.

@jakevdp
Collaborator

jakevdp commented Jun 1, 2023

OK, thanks for the context. I'm assigning @mattjj who could take a look at the alternate vmap implementation proposal.

@patrick-kidger
Collaborator

@francois-rozet -- you may like Equinox, which does many of the things you're discussing here.

@francois-rozet
Author

francois-rozet commented Jun 2, 2023

@francois-rozet -- you may like Equinox, which does many of the things you're discussing here.

Hello @patrick-kidger, I know Equinox; I am actually trying to improve it 😅 (I want to make it more compatible with the default JAX transformations, without having to specify static fields). I'll get back to you when I have a working proof of concept!

@patrick-kidger
Collaborator

Ah, I've actually figured out the answer to that one! Don't use dataclasses (like Equinox does). Instead, override __setattr__ to detect what your fields are during __init__. (Or maybe look in __dict__.)

Either way, inspect the value specifically at init time. If it's an array, mark it dynamic. If it's anything else, mark it static. This preserves pytree semantics, as the inspection is only done at init time and not later during flattening (after which leaves may have been substituted with None, e.g. for the sake of vmap(in_axes=...)).
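A minimal sketch of the init-time approach, assuming a hypothetical Module base class (this is not Equinox's API): __setattr__ records whether each field holds a jax.Array when it is assigned, and that classification travels in the aux data, so flattening never inspects leaf types. Linear and activation are made-up names for the example.

```python
import jax
import jax.tree_util as jtu


class Module:
    """Hypothetical base class: classifies fields when they are assigned."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        jtu.register_pytree_node_class(cls)

    def __setattr__(self, name, value):
        # Classification happens at assignment time (i.e. in __init__), not in tree_flatten.
        kinds = self.__dict__.setdefault("_kinds", {})
        kinds[name] = "dynamic" if isinstance(value, jax.Array) else "static"
        object.__setattr__(self, name, value)

    def tree_flatten(self):
        kinds = self.__dict__.get("_kinds", {})
        dyn = tuple(n for n, k in kinds.items() if k == "dynamic")
        sta = tuple(n for n, k in kinds.items() if k == "static")
        children = tuple(self.__dict__[n] for n in dyn)
        aux = (dyn, sta, tuple(self.__dict__[n] for n in sta))
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        dyn, sta, static_values = aux
        obj = object.__new__(cls)
        # Bypass __setattr__: reuse the stored classification instead of recomputing it.
        kinds = {n: "dynamic" for n in dyn}
        kinds.update({n: "static" for n in sta})
        obj.__dict__["_kinds"] = kinds
        obj.__dict__.update(zip(dyn, children))
        obj.__dict__.update(zip(sta, static_values))
        return obj


class Linear(Module):
    def __init__(self, key, activation="relu"):
        self.w = jax.random.normal(key, (3,))  # array: dynamic
        self.activation = activation  # string: static


keys = jax.random.split(jax.random.PRNGKey(0), 4)
models = jax.vmap(Linear)(keys)
print(models.w.shape, models.activation)  # (4, 3) relu
```

Swapping w for an object() placeholder keeps it a child, because the decision was made when w was first assigned. This sketch reclassifies a field whenever it is reassigned; a stricter version would freeze the classification once __init__ returns.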

I've been contemplating changing Equinox over to the above approach. It's not clear to me what the consequences to backward compatibility are, though.

@francois-rozet
Author

francois-rozet commented Jun 2, 2023

I think this approach would break if you want a "list of arrays", a "dict of modules", a list of Partial objects with modules as functions, or user-defined pytrees ... inside a Module. It is impossible to enumerate all cases because there are infinitely many. And modifying __setattr__ is actually fairly dangerous (there are many edge cases). So I have also moved away from dataclasses and fields, and work with the tree directly. I detect leaves that are either modules or arrays and mark everything else as static during flattening. The advantage is that you can modify anything at run time (the type or value of a leaf, or even the tree structure).

I have a proof of concept in this repo: https://github.com/francois-rozet/inox. The name is a tribute to equinox, and means stainless in French. The core consists of the Namespace and Module classes (in inox/tree_util.py and inox/nn/module.py), which have custom tree_flatten and tree_unflatten methods.

There is a demo.py example. For a few days of work, I think this is coming along nicely! I still have to add BatchNorm and Dropout, which should be OK with the Buffer concept. Tell me what you think!

Edit: Just added BatchNorm and Dropout 😃

@patrick-kidger
Collaborator

I think this approach would break if you want a "list of arrays"

I think this is still doable, by flattening every field into a list of leaves, and checking each one. All non-arrays can be wrapped into a Static object. Then unwrap it during __getattr__, so that the wrapping is invisible to an end user.
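A sketch of the Static-wrapper idea, with hypothetical names (Static, wrap_field, unwrap_field): the classification is done once when a field is wrapped, and unwrap_field stands in for what a __getattr__ hook would return. Because Static keeps its value in aux data, replacing the leaves with object() placeholders leaves the structure unchanged.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Static:
    """Hypothetical wrapper: stores a non-array value as aux data, so it has no leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(value)


def wrap_field(field):
    # Done once (e.g. at __init__ time): arrays stay leaves, everything else is wrapped.
    return jtu.tree_map(lambda x: x if isinstance(x, jax.Array) else Static(x), field)


def unwrap_field(field):
    # What a __getattr__ hook would return: the field with Static wrappers removed.
    return jtu.tree_map(
        lambda x: x.value if isinstance(x, Static) else x,
        field,
        is_leaf=lambda x: isinstance(x, Static),
    )


field = [jnp.ones(2), "relu", 3]
wrapped = wrap_field(field)
leaves, treedef = jtu.tree_flatten(wrapped)
placeholder = jtu.tree_unflatten(treedef, [object()] * len(leaves))

print(len(leaves))                        # 1: only the array is a leaf
print(len(jtu.tree_leaves(placeholder)))  # 1: structure survives the swap
print(unwrap_field(wrapped)[1:])          # ['relu', 3]
```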

This would also be quite nice in that it means we might be able to avoid the confusion over when to use static fields, by simply not having them be part of the public API any more.

WDYT?

I detect leaves that are either modules or arrays and set everything else as a static during the flattening.

The problem is that this just isn't compatible with JAX's model of what a pytree is. (I.e. that the structure does not depend on the types of the leaves.) And indeed Equinox makes heavy use of this invariant too: for example equinox.tree_at does some pretty clever leaf-substituting magic in order to perform its out-of-place updates, and as such even has a specific check that this invariant is satisfied.

The name is a tribute to equinox

Haha, thank you! That's great to see.

Tell me what you think!

I like it! I'm now wondering why I ever called it eqx.partition and not eqx.tree_partition, which would have been much more consistent.

I would caution against your approach to statefulness, using in-place updates. This gets pretty hairy around JAX transforms: if I do

my_model = ...  # uses batch norm internally
leaves, treedef = jtu.tree_flatten(my_model)
my_model2 = jtu.tree_unflatten(treedef, leaves)
my_model2(...)

then the original my_model will not see the batch norm updates. And such flattening/unflattening happens every time you cross a JIT/grad/vmap/etc. API boundary, so it's an easy footgun to hit accidentally.
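A self-contained sketch of this footgun, using a hypothetical Counter pytree that mutates itself in place (standing in for a batch-norm layer updating its running statistics):

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Counter:
    """Hypothetical stateful layer that updates its state in place when called."""

    def __init__(self, count):
        self.count = count

    def __call__(self, x):
        self.count = self.count + 1  # in-place update, like a running statistic
        return x

    def tree_flatten(self):
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


model = Counter(jnp.array(0))

# Explicit round trip: model2 is a new Python object built from the same leaves.
leaves, treedef = jtu.tree_flatten(model)
model2 = jtu.tree_unflatten(treedef, leaves)
model2(1.0)
print(int(model.count), int(model2.count))  # 0 1


# The same thing happens implicitly at a jit boundary.
@jax.jit
def apply(m, x):
    return m(x)


apply(model, 1.0)
print(int(model.count))  # 0
```

Inside apply, m is a copy rebuilt from tracers, so the update is applied to that copy and then discarded; the original model is never touched.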

@francois-rozet
Author

francois-rozet commented Jun 3, 2023

@patrick-kidger Let's take this discussion elsewhere (francois-rozet/inox#1), as it is not directly related to the issue I submitted.

AndPotap referenced this issue in wilson-labs/cola Aug 28, 2023
Address compatibility for pytrees both in Jax and PyTorch, to allow
vmap, jit with LinearOperator arguments or outputs. See e.g. #20 or #26.
Will also enable using CoLA operators within
[equinox](https://github.com/patrick-kidger/equinox) modules.

Replaces custom flattening function (which is slow and not very general)
with optree (for pytorch and numpy backends) and jax.tree_util for jax.
Modifies __setattr__ of LinearOperator to record which of vars(obj) are
dynamic (pytrees or arrays) or static (other).
Then flatten and unflatten can separate the two. This annotation of
which are static and which dynamic needs to be done during the init as
discussed in this jax issue
[https://github.com/google/jax/issues/16170](https://github.com/google/jax/issues/16170)
(even though doing so inside flatten is ostensibly compatible with the
pytree specification).
LinearOperator metaclass is set to one which automatically registers
each linear operator as a pytorch pytree (if installed), optree pytree,
and jax pytree (if installed). Optree is used for numpy linear operators
and will also eventually replace the pytorch pytrees as per the
intentions of the pytorch devs.


With this functionality it should also be possible to construct batched
linear operators and use them in batched matrix routines.
E.g. 
```python
Ds = vmap(cola.ops.Diagonal)(randn(10,1000))
logdets = vmap(cola.linalg.logdet)(Ds)
```

It may also be possible to simplify our custom autograd rules once
LinearOperators are pytrees though I have not attempted this yet.
@francois-rozet
Author

francois-rozet commented Jan 4, 2024

Hello @jakevdp, @mattjj 👋 I have found a workaround for this issue that makes custom PyTrees which auto-detect non-array leaves compatible with the current jax.vmap. The idea is to detect whether a leaf has been replaced by an object() instance (which only happens in jax.vmap).

import jax

@jax.tree_util.register_pytree_node_class
class Custom:
    def __init__(self, key, switch=True):
        if switch:
            self.x = jax.random.normal(key)  # not-static
        else:
            self.x = 'static'

    def tree_flatten(self):
        if isinstance(self.x, jax.Array) or type(self.x) is object:  # also check whether x is an object()
            return [self.x], [None]
        else:
            return [None], [self.x]

    @classmethod
    def tree_unflatten(cls, static, children):
        y, x = static[0], children[0]
        x = x if y is None else y

        self = object.__new__(cls)
        self.x = x

        return self

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)

jax.vmap(Custom)(keys)

However, I still think that the sketch implementation I proposed in #16170 (comment) is better than the current one because it is (a) simpler, as it does not use generators, (b) faster, as it flattens/unflattens fewer trees, and (c) compatible with arbitrary custom PyTrees, as it does not assume the tree structure is independent from the leaf types.

Note that JAX in fact already does not assume that the tree structure is independent from the leaf types: replacing a leaf with None leads to a different tree structure and number of leaves.
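A quick check of that claim with the public tree_util API (the dictionary contents are arbitrary):

```python
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"a": jnp.ones(2), "b": jnp.zeros(3)}
tree_with_none = {"a": jnp.ones(2), "b": None}  # one leaf replaced by None

n_full = jtu.tree_structure(tree).num_leaves
n_none = jtu.tree_structure(tree_with_none).num_leaves
print(n_full, n_none)  # 2 1
print(jtu.tree_structure(tree) == jtu.tree_structure(tree_with_none))  # False
```

JAX registers None as a node with no children, so substituting it for a leaf removes that leaf from the flattened list and changes the treedef.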

@yashk2810
Collaborator

Won't the pytree registry help here?

@francois-rozet
Author

Won't the pytree registry help here?

I am not sure what you mean.

@patrick-kidger
Collaborator

@francois-rozet -- I would really strongly discourage this leaf-type-detection approach. Since JAX pytrees can in principle have leaves of any type, third-party code can and does make use of this. Your work may be compatible with core JAX (or at least any errors are passing silently, rather than failing loudly), but it might break when used with other libraries.

As an example of another library running afoul of this, see google-deepmind/distrax#193. Again there is leaf type detection during flattening, and this then breaks things later.
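To make the concern concrete, here is a sketch reusing the workaround's tree_flatten logic. It only recognizes a bare object() placeholder, so any other placeholder type silently changes the number of leaves. The Sentinel class is a hypothetical stand-in for what some other library might substitute.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Custom:
    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # the January 2024 workaround: arrays and bare object() count as leaves
        if isinstance(self.x, jax.Array) or type(self.x) is object:
            return [self.x], [None]
        return [None], [self.x]

    @classmethod
    def tree_unflatten(cls, static, children):
        y, x = static[0], children[0]
        obj = object.__new__(cls)
        obj.x = x if y is None else y
        return obj


class Sentinel:  # hypothetical placeholder type from another library
    pass


treedef = jtu.tree_structure(Custom(jnp.ones(3)))
counts = {
    type(p).__name__: len(jtu.tree_leaves(jtu.tree_unflatten(treedef, [p])))
    for p in (object(), Sentinel())
}
print(counts)  # {'object': 1, 'Sentinel': 0}
```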

I think if you're interested in doing auto-[dynamic/static] leaf type detection, then I'd propose doing it at __init__ time and storing the result in the flattened aux data, so that the result then remains consistent across all future operations.

francois-rozet closed this as not planned on Jan 8, 2024
5 participants