Designing smart modules #1
-
I pushed a new branch ( What remains is to tackle the issue of the "stateful" layers, like BatchNorm and Dropout. I think a functional approach where the layers return their own state is too cumbersome and introduces a lot of complexity (for both users and developers). As you said, my approach fails silently when the module enters "pure function" territory, that is We can either warn the user and give them the responsibility to handle this or introduce small wrappers around
-
So my strongest opinion is about this statement:
I strongly disagree! Equinox makes heavy use of the assumption. I previously mentioned that it's what's needed to make Regarding simplicity of Moving on to in-place updates: Equinox deliberately has frozen modules that you cannot modify outside of
-
Once again I disagree. Changing this in JAX now would lead to further fragmentation in the ecosystem. (Which frankly we already have more than enough of -- Equinox vs Haiku vs Flax being the most prominent example.) Right now Equinox is compatible with every pytree, and thus compatible with essentially everything in the JAX ecosystem. If we make this distinction of a privileged "special" kind of pytree then that goes away -- we end up with a mini "Equinox ecosystem" as being the only place where we could make compatibility guarantees. From a library author perspective: the loss of a shared language makes it much harder for new projects to get started. From an end user perspective: it means that two unrelated libraries are much less likely to be compatible.
Flattening and unflattening happens implicitly in many places in JAX. Every time you cross
I'd argue this is a feature. For example when doing batch norm + neural ODEs, you will actually evaluate the layer multiple times, but only want to do the stateful update just once. (Else you get different normalisations at different points of the solve, which breaks the ODE-like structure.) Decoupling evaluation from stateful update makes this possible.
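To make the decoupling concrete, here is a minimal sketch and not the Equinox implementation. It assumes a simplified batch norm that normalises with stored running statistics; `batch_norm`, `update_state` and the Euler loop are hypothetical names made up for this example.

```python
import jax.numpy as jnp


def batch_norm(x, state, eps=1e-5):
    """Evaluate the layer with the stored statistics; the state is not modified."""
    mean, var = state
    return (x - mean) / jnp.sqrt(var + eps)


def update_state(x, state, momentum=0.9):
    """Return new running statistics; called separately from evaluation."""
    mean, var = state
    new_mean = momentum * mean + (1 - momentum) * x.mean(axis=0)
    new_var = momentum * var + (1 - momentum) * x.var(axis=0)
    return new_mean, new_var


x0 = jnp.arange(8.0).reshape(4, 2)
state = (jnp.zeros(2), jnp.ones(2))

# ODE-like solve (explicit Euler): the layer is evaluated several times,
# always with the same statistics.
x = x0
for _ in range(3):
    x = x + 0.1 * batch_norm(x, state)

# The stateful update happens exactly once, outside the solve.
new_state = update_state(x0, state)
```

Because evaluation never touches the state, every step of the solve sees the same normalisation, and the caller decides when (and with which batch) the statistics are refreshed.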
-
Hello, I ran into the JAX problem described here. I have authored a pytree library and tested first-hand some of the suggestions/ideas that @francois-rozet / @patrick-kidger mentioned. From my experience as a user/designer of pytree-based libraries, it's tempting to add cool features like automatic detection of static fields or overriding some magic methods (like PyTorch itself does in 'Module'). This is fine if the user knows what's happening underneath. However, this is only true on occasion. For in-place updates, the immutability assumption is a fundamental assumption to abide by; if you break it, I am pretty sure you will encounter all sorts of problems (as I have discovered). For tree modification, in my library this P.S. I really like the logo. :D
-
Hello @patrick-kidger and @ASEM000 👋 I hope you are doing well. Just to say that I finally had the time to finish the interface of Inox. I think it is quite nice now! In the end, detecting and wrapping static leaves on I therefore went back to detecting static leaves during flattening, which is slightly slower (a factor of 5 for flattening in my experiments) but remains negligible with respect to actual tensor operations. In addition, flattening/unflattening can be avoided (with Note that I previously had an issue with Thank you for the very interesting discussions!
-
Happy to hear from you again and happy holidays! So, IMO, I came to two conclusions: 1) keep everything explicit; 2) every instance of the same class should have the same flatten rule. For 1), the problem is that a) you are doing extra host work every time you invoke a for loop/tree_map + conditionals to decide which leaves are static/frozen (this scales with your leaf count), and the performance hit becomes visible for small nets; b) if you hide the process from the user, then detecting what is trainable/not trainable may be a footgun, and recompilation behavior becomes unclear. For 2), I want to follow standard pytree rules like those of list/dict, where every instance has the same rule. This simplifies reasoning about them. So I prefer the flatten step ~ vars(tree) to match dict. This speeds up the overall performance compared to other implementations. My current approach is to use leafless wrapper (
Correct me if I'm wrong: your main motivation is to hide static/non-static decisions from the user at the module level (smart modules), so you don't have to recreate all JAX transforms to understand this (like Patrick's work with filtered JAX transforms). I think this is a good goal, but my feeling (might be incorrect, of course) is that eventually users have to understand this point in case they hit some error or weird behavior when interacting with other libraries or JAX transforms. You can see sample issues in Flax here / Equinox here and here, where there is some confusion about static/non-static and jaxtype/non-jaxtype behavior. If we promise the user that we will handle everything on their behalf, it's frustrating when they hit an error or behavior they don't understand and then have to understand it to debug it. So we either keep our promise and handle everything on their behalf (which is hard if we are dealing with low-level JAX), or we make no promise and introduce the concepts of static/non-static and jaxtype/non-jaxtype to the user, at the expense of a steeper learning curve. I am interested in your statefulness handling; I will read more about your stateful modules, as this is (IMO) one of the tricky problems in pytree-based approaches. Let me know what you think.
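For readers following along, the "flatten step ~ vars(tree)" idea can be sketched roughly as below. This is an illustration under assumptions, not the code of any of the libraries mentioned: the `Namespace` class is hypothetical, and only `jax.tree_util.register_pytree_node_class` and `tree_map` are real JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Namespace:
    """Hypothetical class whose flatten rule mirrors dict: every attribute is a child."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def tree_flatten(self):
        # Same rule for every instance: children are the attribute values,
        # aux data is the (sorted) tuple of attribute names.
        keys = tuple(sorted(vars(self)))
        children = tuple(vars(self)[k] for k in keys)
        return children, keys

    @classmethod
    def tree_unflatten(cls, keys, children):
        # Bypass __init__ so unflattening never runs user logic.
        obj = object.__new__(cls)
        vars(obj).update(zip(keys, children))
        return obj


layer = Namespace(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)
```

No conditionals run during flattening here, so the host-side cost is a single pass over the attributes, and whether a value is static has to be stated explicitly elsewhere (for instance by wrapping it).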
-
Hello @patrick-kidger and @ASEM000, sorry to bother you again. I tried my last idea (see #1 (reply in thread)) and it was terrible 🤣 In the end, I think you guys are right: automatic detection of static leaves during flattening makes Inox modules (a) error-prone around JAX transformations and (b) incompatible with part of the JAX ecosystem. Therefore, I decided to make my modules dumb PyTrees (namespaces) and adopted the lifted transformation approach of Equinox. Roughly, for a function f,
    y = inox.transform(f)(x)
is equivalent to
    g = lambda x: tree_mask(f(tree_unmask(x)))
    y = tree_unmask(jax.transform(g)(tree_mask(x)))
The lifting is not tailored to each transformation, which means that Inox transformations might be slightly slower than Equinox transformations. However, the simplicity and generality make it possible to keep exactly the same interface as the base transformation. For better performance, users should use
These changes effectively make Inox a mini-version of Equinox, which I mention in the README along with Serket for the inspiration. Thanks again for the discussions!
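The lifting described above can be sketched as follows. This is a rough illustration, not Inox's actual code: the `Static` wrapper, the bodies of `tree_mask` / `tree_unmask` and the `lift` helper are assumptions made up for this example; only the `jax.tree_util` calls and `jax.jit` are real JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_map


class Static:
    """Hypothetical leafless wrapper: the value is stored in the treedef, not in the leaves."""

    def __init__(self, value):
        self.value = value


register_pytree_node(
    Static,
    lambda s: ((), s.value),                 # no children, value as aux data
    lambda value, children: Static(value),
)


def tree_mask(tree):
    # Wrap every non-array leaf so that JAX transformations never see it as a leaf.
    return tree_map(lambda x: x if isinstance(x, jax.Array) else Static(x), tree)


def tree_unmask(tree):
    is_static = lambda x: isinstance(x, Static)
    return tree_map(lambda x: x.value if is_static(x) else x, tree, is_leaf=is_static)


def lift(transform):
    # Generic lifting: the same wrapper is used for every transformation.
    def lifted(f, *args, **kwargs):
        g = lambda x: tree_mask(f(tree_unmask(x)))
        h = transform(g, *args, **kwargs)
        return lambda x: tree_unmask(h(tree_mask(x)))
    return lifted


jit = lift(jax.jit)

params = {"w": jnp.array([1.0, 2.0]), "name": "linear"}
out = jit(lambda p: {"y": 2 * p["w"], "name": p["name"]})(params)
```

Since masked values end up in the tree structure, they must be hashable and comparable, and changing one of them triggers recompilation under `jax.jit`.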
-
Hello @patrick-kidger, we can continue the discussion we started in jax-ml/jax#16170 here. I chose to do it in this repo rather than pollute the issues of Equinox. To summarize, we agree that automatically detecting static leaves would be better than the current Equinox behavior. We both propose an approach that is based on the tree.
My approach is to detect static leaves when flattening the tree. The advantage is that everything, from the types to the very structure of the tree, can be modified in-place, as users would expect from a Python object. The main drawback is that it does not work with the current implementation of jax.vmap, which assumes that the tree structure does not change when leaves are modified. In my opinion, this is actually a limitation of JAX and should/could be fixed. In fact, JAX already has a similar behavior with None objects.
You propose to detect static leaves when attributes are set (in __setattr__). The advantage is that tree_map (or equivalent composition of tree_flatten and tree_unflatten) would not change the tree structure. The main drawback is that in-place modifications of attributes would not work as users expect, as __getattr__ would create deep copies of the attribute. For instance,
It might also add some overhead (although compiled away by jax.jit).
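As a small illustration of the None behavior mentioned above: JAX treats None as an empty subtree rather than a leaf, so mapping a leaf to None changes the tree structure. The tree below is a made-up example; the `jax.tree_util` functions are the standard ones.

```python
import jax
import jax.numpy as jnp

tree = {"weight": jnp.ones(3), "bias": None}

# None is not a leaf: only the array is returned.
leaves, treedef = jax.tree_util.tree_flatten(tree)
num_leaves = len(leaves)

# Mapping the remaining leaf to None yields a tree with a different structure.
pruned = jax.tree_util.tree_map(lambda x: None, tree)
same_structure = jax.tree_util.tree_structure(pruned) == treedef
```

Here `num_leaves` is 1 and `same_structure` is False, which is the kind of leaf-dependent structure that detection at flattening time would generalise.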