diff --git a/docs/experimental/nnx/nnx_basics.ipynb b/docs/experimental/nnx/nnx_basics.ipynb index 22fea1f92c..11e0897d2d 100644 --- a/docs/experimental/nnx/nnx_basics.ipynb +++ b/docs/experimental/nnx/nnx_basics.ipynb @@ -411,7 +411,7 @@ } ], "source": [ - "state, static = model.split()\n", + "static, state = model.split()\n", "\n", "print(f'{state = }\\n')\n", "print(f'{static = }'[:200] + '...')" @@ -450,7 +450,7 @@ "print(f'{model.count = }')\n", "\n", "# 1. Use split to create a pytree representation of the Module\n", - "state, static = model.split()\n", + "static, state = model.split()\n", "\n", "@jax.jit\n", "def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", @@ -532,7 +532,7 @@ ], "source": [ "# use Variable type filters to split into multiple States\n", - "params, counts, static = model.split(nnx.Param, Count)\n", + "static, params, counts = model.split(nnx.Param, Count)\n", "\n", "print(f'{params = }\\n')\n", "print(f'{counts = }')" diff --git a/docs/experimental/nnx/nnx_basics.md b/docs/experimental/nnx/nnx_basics.md index 405e274a08..65a20c6ba6 100644 --- a/docs/experimental/nnx/nnx_basics.md +++ b/docs/experimental/nnx/nnx_basics.md @@ -204,7 +204,7 @@ a Module graph, its analogous to JAX's `PyTreeDef`, and for convenience it implements an empty pytree. ```{code-cell} ipython3 -state, static = model.split() +static, state = model.split() print(f'{state = }\n') print(f'{static = }'[:200] + '...') @@ -222,7 +222,7 @@ updates from a transform back to the source object outside. print(f'{model.count = }') # 1. Use split to create a pytree representation of the Module -state, static = model.split() +static, state = model.split() @jax.jit def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array): @@ -269,7 +269,7 @@ Variable types as shown below. ```{code-cell} ipython3 # use Variable type filters to split into multiple States -params, counts, static = model.split(nnx.Param, Count) +static, params, counts = model.split(nnx.Param, Count) print(f'{params = }\n') print(f'{counts = }')