Using _is_jax_data for tree flattening results in incompatibility with some tree_map operations #193

Open
rciric opened this issue Aug 27, 2022 · 1 comment

rciric commented Aug 27, 2022

In certain cases, it seems like using _is_jax_data as a criterion for flattening trees can lead to structural incompatibilities, which can then result in errors when mapping over trees derived from distrax distributions.

To elaborate: let's consider that we have a model represented as a PyTree of parameters and metadata, and that this model contains a distrax distribution (or more generally Jittable) as a child node. We now wish to perform some selective update or partition operation on our model tree — for instance, to separate the tree into DeviceArray leaves and non-DeviceArray leaves. To do this, we will first perform a tree_map on our existing tree, mapping leaves that match the selection criterion to True and leaves that don't match to False. We will then use this mapped “mask” tree to specify the leaves to set to None on either side of the partition.
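The mask-and-partition workflow described above can be sketched roughly as follows. This is a minimal illustration on a plain dict, where it works as expected; the helper names `is_array`, `partition`, and `combine` and the toy `model` are assumptions made for this sketch, not part of distrax or of any particular library.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array(leaf):
    # Instance check for array leaves (NumPy or JAX arrays).
    return isinstance(leaf, (np.ndarray, jnp.ndarray))


def partition(tree, predicate=is_array):
    # Boolean mask with the same structure as `tree`.
    mask = jax.tree_util.tree_map(predicate, tree)
    # Keep matching leaves on one side and replace them with None on the other.
    selected = jax.tree_util.tree_map(
        lambda leaf, keep: leaf if keep else None, tree, mask
    )
    rest = jax.tree_util.tree_map(
        lambda leaf, keep: None if keep else leaf, tree, mask
    )
    return selected, rest, mask


def combine(selected, rest):
    # Treat None as a leaf so that both halves line up position by position.
    return jax.tree_util.tree_map(
        lambda a, b: b if a is None else a,
        selected,
        rest,
        is_leaf=lambda x: x is None,
    )


# Hypothetical toy model: two array leaves and one metadata leaf.
model = {"weight": jnp.ones((2, 2)), "bias": jnp.zeros(2), "activation": "relu"}
arrays, static, mask = partition(model)
restored = combine(arrays, static)
```

The second and third `tree_map` calls only succeed because `model` and `mask` flatten to the same structure, which is the assumption that breaks once a node's flattening depends on the type of its leaves.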

Unfortunately, this is where we hit a snag. Since our mask tree now contains boolean values in place of DeviceArrays, _is_jax_data returns False for leaves of the mask tree where it returned True for the corresponding leaves of the original tree, so the children field could be left empty for the mask tree. Because the flattened distribution and mask trees then no longer share the same structure, we cannot use the mask tree to create our partition. (Side note: even if we didn't create a mask tree for our partition, we'd still end up with None on the side of the partition without DeviceArrays, ultimately resulting in the same structural incompatibility if we later wish to undo the partition.) I'm not actually sure whether the data-based flattening switch is the only cause here, but I wanted to share my observations.

Here is a minimal reproducible example demonstrating the issue:

import jax, distrax

tree = distrax.Normal(0, 1)
mask = jax.tree_util.tree_map(lambda _: True, tree)

jax.tree_util.tree_map(
    lambda l, r: l,
    tree,
    mask
)

Results in:

ValueError: Mismatch custom node data:
([None, None], [True, True], PyTreeDef({'_loc': *, '_scale': *})) !=
([True, True], [False, False], PyTreeDef({'_loc': *, '_scale': *}));
value: <distrax._src.distributions.normal.Normal object at 0x7fd388c36680>.
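The same kind of mismatch can be reproduced without distrax. The sketch below uses a hypothetical `DataDependentNode` whose flattening depends on the type of its fields. It is only loosely modelled on the behaviour described above and is not distrax's actual Jittable implementation; `looks_like_data` is an assumed stand-in for a data-based check.

```python
import jax
import jax.numpy as jnp
import numpy as np


def looks_like_data(x):
    # Assumed stand-in for a data-based criterion: arrays count as data.
    return isinstance(x, (np.ndarray, jnp.ndarray))


@jax.tree_util.register_pytree_node_class
class DataDependentNode:
    """Hypothetical node: fields that look like data become children,
    everything else is stored in the static auxiliary data."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        values = (self.loc, self.scale)
        is_data = tuple(looks_like_data(v) for v in values)
        children = tuple(v for v, d in zip(values, is_data) if d)
        static = tuple(v for v, d in zip(values, is_data) if not d)
        return children, (is_data, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        is_data, static = aux
        children, static = iter(children), iter(static)
        return cls(*(next(children) if d else next(static) for d in is_data))


tree = DataDependentNode(jnp.zeros(()), jnp.ones(()))
# Mapping to booleans moves both fields from children into the auxiliary data.
mask = jax.tree_util.tree_map(lambda _: True, tree)

same_structure = (
    jax.tree_util.tree_structure(tree) == jax.tree_util.tree_structure(mask)
)

try:
    jax.tree_util.tree_map(lambda l, r: l, tree, mask)
    mismatch_raised = False
except ValueError:
    # The auxiliary data of `tree` and `mask` differ, so JAX rejects the pair.
    mismatch_raised = True
```

In this sketch the original node flattens to two children, while the mask flattens to zero children with both booleans stored as static data, so the two tree definitions no longer agree.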

As a result of this design choice, distrax distributions are not currently compatible with equinox’s filter transforms, like eqx.filter_jit. This doesn't actually matter much for my use case — I can mark any model fields that are distrax.Distributions as static without recompiling since the instance doesn't change — but it is possible there are other use cases where this could make a difference.

Details

JAX v0.3.16
distrax v0.1.2 (nightly from c013670)
Running on CPU

rciric added a commit to hypercoil/hypercoil that referenced this issue Aug 27, 2022
@patrick-kidger

Haha, I just stumbled across this issue as well.

Whilst we're here, _is_jax_data also uses a try-except around abstractify, which is probably going to really hurt performance.

Equinox actually used to do something similar to filter arrays from non-arrays, but switched to doing instance checks (e.g. isinstance(leaf, (np.ndarray, jnp.ndarray))) because this approach can be very slow.
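To make the contrast concrete, here is a rough sketch of the two styles of check. `_to_abstract` is a hypothetical stand-in that simply raises `TypeError` for non-arrays; it is not the real abstractify routine and only serves to isolate the cost of exception handling. The timings are printed rather than asserted, since they depend on the machine and on the mix of leaves.

```python
import timeit

import jax.numpy as jnp
import numpy as np

ARRAY_TYPES = (np.ndarray, jnp.ndarray)


def is_array_isinstance(leaf):
    # Instance-check style, as in the comment above.
    return isinstance(leaf, ARRAY_TYPES)


def _to_abstract(leaf):
    # Hypothetical stand-in for a conversion that raises on non-arrays.
    if not isinstance(leaf, ARRAY_TYPES):
        raise TypeError(f"not array-like: {type(leaf).__name__}")
    return leaf.shape, leaf.dtype


def is_array_try_except(leaf):
    # Exception-driven style: every non-array leaf pays for a raise and a catch.
    try:
        _to_abstract(leaf)
        return True
    except TypeError:
        return False


leaves = [jnp.ones(3), np.zeros(2), "relu", 3, None, True]
flags_isinstance = [is_array_isinstance(x) for x in leaves]
flags_try_except = [is_array_try_except(x) for x in leaves]

# Rough timing comparison over the same leaves.
t_isinstance = timeit.timeit(
    lambda: [is_array_isinstance(x) for x in leaves], number=10_000
)
t_try_except = timeit.timeit(
    lambda: [is_array_try_except(x) for x in leaves], number=10_000
)
print(f"isinstance: {t_isinstance:.4f}s, try/except: {t_try_except:.4f}s")
```

Both predicates classify these leaves identically; they differ only in how a non-array leaf is rejected.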

Modulo these issues, the Jittable base class used here is basically doing the same thing as eqx.Module. I think realistically it's unlikely to happen (this repo hasn't seen any activity in a while) but one possible fix might be to replace it with eqx.Module.
