Using `_is_jax_data` for tree flattening results in incompatibility with some `tree_map` operations #193
Haha, I just stumbled across this issue as well. Whilst we're here, Equinox actually used to do something similar to filter arrays from non-arrays, but switched to doing instance checks (e.g. … Modulo these issues, the …
In certain cases, using `_is_jax_data` as a criterion for flattening trees can lead to structural incompatibilities, which can then result in errors when mapping over trees derived from `distrax` distributions.

To elaborate: suppose we have a model represented as a PyTree of parameters and metadata, and that this model contains a `distrax` distribution (or, more generally, a `Jittable`) as a child node. We now wish to perform some selective update or partition operation on our model tree, for instance to separate the tree into `DeviceArray` leaves and non-`DeviceArray` leaves. To do this, we first perform a `tree_map` on our existing tree, mapping leaves that match the selection criterion to `True` and leaves that don't match to `False`. We then use this mapped “mask” tree to specify which leaves to set to `None` on either side of the partition.

Unfortunately, this is where we hit a snag. Since our mask tree now contains boolean values in place of `DeviceArray`s, `_is_jax_data` returns `False` for the mask tree where it returned `True` for the original tree, so the `children` field can be left empty for the mask tree. Because the flattened distribution and mask trees then no longer share the same structure, we cannot use the mask tree to create our partition. (Side note: even if we didn't create a mask tree for the partition, we'd still end up with `None` on the side of the partition without `DeviceArray`s, ultimately resulting in the same structural incompatibility if we later wish to undo the partition.) I'm not actually sure whether the data-based flattening switch is the only cause here, but I wanted to share my observations.

Here is a minimal reproducible example demonstrating the issue:
Results in:
As a result of this design choice, `distrax` distributions are not currently compatible with `equinox`’s filter transforms, like `eqx.filter_jit`. This doesn't actually matter much for my use case: I can mark any model fields that are `distrax.Distribution`s as static without triggering recompilation, since the instance doesn't change. It is possible, though, that there are other use cases where this could make a difference.

Details
JAX v0.3.16
distrax v0.1.2 (nightly from c013670)
Running on CPU
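For anyone who wants to see the mechanism in isolation, here is a minimal sketch in plain JAX, independent of distrax. All names in it (`DataDependentBox`, `StableBox`, `looks_like_array`, `is_array_mask`, `keep_masked`) are hypothetical, and `looks_like_array` is only a simplified stand-in for a data-based check like `_is_jax_data`, not distrax's actual implementation. It also assumes a JAX release recent enough to support `isinstance(x, jax.Array)`, rather than the v0.3.16 listed above.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def looks_like_array(x):
    # Hypothetical stand-in for a data-based check such as `_is_jax_data`:
    # only JAX arrays count as data.
    return isinstance(x, jax.Array)


@tree_util.register_pytree_node_class
class DataDependentBox:
    """Hypothetical container whose pytree structure depends on its contents."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        if looks_like_array(self.value):
            # Arrays are exposed as a child, i.e. one pytree leaf.
            return (self.value,), ("array", None)
        # Anything else (e.g. a bool from a mask) is hidden in aux data,
        # so the node has no children at all.
        return (), ("static", self.value)

    @classmethod
    def tree_unflatten(cls, aux, children):
        kind, static_value = aux
        return cls(children[0] if kind == "array" else static_value)


@tree_util.register_pytree_node_class
class StableBox:
    """Hypothetical container that always exposes `value` as a child."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


def is_array_mask(tree):
    # Map every leaf to True/False, as in the partitioning step of the issue.
    return tree_util.tree_map(looks_like_array, tree)


def keep_masked(tree, mask):
    # Replace leaves whose mask entry is False with None.
    return tree_util.tree_map(lambda x, keep: x if keep else None, tree, mask)


model = DataDependentBox(jnp.ones(3))
mask = is_array_mask(model)  # DataDependentBox(True), which flattens to zero leaves
print(len(tree_util.tree_leaves(model)), len(tree_util.tree_leaves(mask)))  # 1 0

try:
    keep_masked(model, mask)
    mismatch_raised = False
except ValueError as err:  # JAX rejects the mismatched custom-node structure
    mismatch_raised = True
    print("structure mismatch:", err)

stable_model = StableBox(jnp.ones(3))
stable_mask = is_array_mask(stable_model)  # StableBox(True), still one leaf
kept = keep_masked(stable_model, stable_mask)
print(tree_util.tree_structure(stable_model) == tree_util.tree_structure(stable_mask))  # True
```

The point of the contrast: when a node decides how to flatten by inspecting the values it holds, mapping those values to something else (a bool, or `None`) changes the tree's structure, so the mapped tree no longer lines up with the original. Keeping the structure fixed and doing the array/non-array test per leaf, as `StableBox` plus `is_array_mask` does, is roughly the instance-check approach mentioned in the comment above.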