```
Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 35, in <module>
    vmap(f, in_axes=(None, 0))(module, z)
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/api.py", line 1221, in vmap_f
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 400, in flatten_axes
    dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/jax/_src/tree_util.py", line 84, in tree_unflatten
    return treedef.unflatten(leaves)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/module.py", line 444, in _module_unflatten
    return graph.merge(graphdef, State(zip(paths, variables)))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 1306, in merge
    node, _ = unflatten(graphdef, state)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 471, in unflatten
    node = _graph_unflatten(
           ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 628, in _graph_unflatten
    children = _get_children()
               ^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 569, in _get_children
    children[key] = _graph_unflatten(
                    ^^^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 628, in _graph_unflatten
    children = _get_children()
               ^^^^^^^^^^^^^^^
  File "/home/neil/src/cmm/.venv/lib/python3.12/site-packages/flax/nnx/nnx/graph.py", line 575, in _get_children
    raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')
ValueError: Expected a leaf for 'epsilon', but got <object object at 0x7ccaa815f050>
```
Assigning a float to epsilon makes the problem disappear.
(Tested on main and latest.)
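The traceback shows where this breaks: to match `in_axes=(None, 0)` against the arguments, `jax.vmap` rebuilds the argument tree with placeholder `object()` leaves (`flatten_axes` calling `tree_unflatten`), and NNX's unflatten then rejects a placeholder where it expects an array leaf. The sketch below reproduces the same failure mode with plain JAX only. It is an illustration under stated assumptions, not the reporter's `a.py` (which is not shown): `HypotheticalNorm` and `f` are made-up names, and the rule "array attributes are leaves, floats are static" is modeled on the behavior described in this issue.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class HypotheticalNorm:
  """Illustrative pytree: array attributes are leaves, floats are static aux data."""

  def __init__(self, epsilon):
    self.epsilon = epsilon

  def tree_flatten(self):
    if isinstance(self.epsilon, jax.Array):
      return (self.epsilon,), None  # array -> leaf
    return (), self.epsilon  # float -> static aux data, no leaves

  @classmethod
  def tree_unflatten(cls, aux, children):
    children = tuple(children)
    if not children:  # epsilon was stored as static aux data
      return cls(aux)
    (epsilon,) = children
    # Validating leaves here is the mistake: vmap may pass placeholder objects.
    if not isinstance(epsilon, jax.Array):
      raise ValueError(f"Expected a leaf for 'epsilon', but got {epsilon!r}")
    return cls(epsilon)


def f(module, z):
  return z / (jnp.abs(z) + module.epsilon)


z = jnp.arange(3.0)
vf = jax.vmap(f, in_axes=(None, 0))  # broadcast the module, map over z

# epsilon as a JAX array: fails while vmap processes in_axes.
try:
  vf(HypotheticalNorm(jnp.array(1e-6)), z)
  array_error = None
except ValueError as e:
  array_error = e

# epsilon as a float: there is no leaf to replace, so vmap succeeds.
out = vf(HypotheticalNorm(1e-6), z)
```

This matches the workaround reported above, and it is why JAX's pytree documentation advises against validating leaf values inside `tree_unflatten`: transformations such as `vmap` may call it with arbitrary placeholder objects.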
NumPy and JAX Arrays are no longer considered state leaves. This makes the structure of the State completely determined by Variables. Apart from being more predictable, this gives structural stability that is invariant to changes in leaf type; such changes previously led to issues such as #4142.
```python
import jax.numpy as jnp
from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = jnp.array(1)  # no longer allowed, instead...
    self.b = nnx.Param(jnp.array(1))  # just use Variables
```
Also migrates all remaining tests from pytest to absl to ensure they are tested correctly internally.
PiperOrigin-RevId: 671372717
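The structural-stability argument can be illustrated with plain JAX pytrees. This is a sketch of the principle, not Flax's implementation: `Var` is a hypothetical stand-in for `nnx.Variable`, and `ByLeafType` and `ByVariable` are illustrative classes. The first decides what counts as state by leaf type (arrays are state), the second only by whether an attribute is wrapped; the example compares tree structures when `epsilon` changes from a float to an array.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Var:
  """Hypothetical stand-in for nnx.Variable: marks an attribute as state."""

  def __init__(self, value):
    self.value = value


@tree_util.register_pytree_node_class
class ByLeafType:
  """Sketch of a leaf-type rule: JAX-array attributes are state, the rest is static."""

  def __init__(self, **attrs):
    self.attrs = attrs

  def tree_flatten(self):
    keys = sorted(self.attrs)
    state = tuple(k for k in keys if isinstance(self.attrs[k], jax.Array))
    static = tuple((k, self.attrs[k]) for k in keys if k not in state)
    return tuple(self.attrs[k] for k in state), (state, static)

  @classmethod
  def tree_unflatten(cls, aux, children):
    state, static = aux
    return cls(**dict(zip(state, children)), **dict(static))


@tree_util.register_pytree_node_class
class ByVariable:
  """Sketch of a Variable rule: only Var-wrapped attributes are state."""

  def __init__(self, **attrs):
    self.attrs = attrs

  def tree_flatten(self):
    keys = sorted(self.attrs)
    state = tuple(k for k in keys if isinstance(self.attrs[k], Var))
    static = tuple((k, self.attrs[k]) for k in keys if k not in state)
    return tuple(self.attrs[k].value for k in state), (state, static)

  @classmethod
  def tree_unflatten(cls, aux, children):
    state, static = aux
    wrapped = {k: Var(v) for k, v in zip(state, children)}
    return cls(**wrapped, **dict(static))


w = jnp.ones(2)
structure = tree_util.tree_structure

# Leaf-type rule: changing epsilon from float to array changes the structure.
old_same = structure(ByLeafType(w=w, epsilon=1e-6)) == structure(
    ByLeafType(w=w, epsilon=jnp.array(1e-6)))

# Variable rule: structure depends only on which attributes are wrapped in Var.
new_same = structure(ByVariable(w=Var(w), epsilon=Var(1e-6))) == structure(
    ByVariable(w=Var(w), epsilon=Var(jnp.array(1e-6))))

print(old_same, new_same)  # False True
```

Anything that compares tree structures, such as the carry check in `jax.lax.scan`, would treat the two `ByLeafType` instances as incompatible, while the two `ByVariable` instances match.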