Skip to content

Commit

Permalink
[nnx] remove State
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 5, 2024
1 parent cc75368 commit 223bb05
Show file tree
Hide file tree
Showing 38 changed files with 582 additions and 738 deletions.
10 changes: 5 additions & 5 deletions docs_nnx/guides/checkpointing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -289,15 +289,15 @@
],
"source": [
"# Save as pure dict\n",
"pure_dict_state = state.to_pure_dict()\n",
"pure_dict_state = nnx.to_pure_dict(state)\n",
"nnx.display(pure_dict_state)\n",
"checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n",
"\n",
"# Restore as a pure dictionary.\n",
"restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n",
"abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graphdef, abstract_state = nnx.split(abstract_model)\n",
"abstract_state.replace_by_pure_dict(restored_pure_dict)\n",
"nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)\n",
"model = nnx.merge(graphdef, abstract_state)\n",
"assert model(x).shape == (3, 4) # The model still works!"
]
Expand Down Expand Up @@ -325,7 +325,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -379,7 +379,7 @@
"# Same restore code as above.\n",
"abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graphdef, abstract_state = nnx.split(abstract_model)\n",
"abstract_state.replace_by_pure_dict(restored_pure_dict)\n",
"nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)\n",
"model = nnx.merge(graphdef, abstract_state)\n",
"assert model(x).shape == (3, 4) # The new model works!\n",
"\n",
Expand Down
6 changes: 3 additions & 3 deletions docs_nnx/guides/checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ When interacting with checkpoint libraries (like Orbax), you may prefer to work

```{code-cell} ipython3
# Save as pure dict
-pure_dict_state = state.to_pure_dict()
+pure_dict_state = nnx.to_pure_dict(state)
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
-abstract_state.replace_by_pure_dict(restored_pure_dict)
+nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
```
Expand Down Expand Up @@ -181,7 +181,7 @@ restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
-abstract_state.replace_by_pure_dict(restored_pure_dict)
+nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The new model works!
Expand Down
97 changes: 41 additions & 56 deletions docs_nnx/guides/filters_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
" )\n",
"})\n",
"batch_stats = State({\n",
" 'b': VariableState(\n",
" type=BatchStat,\n",
" value=True\n",
" )\n",
"})\n"
"params = {'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
")}\n",
"batch_stats = {'b': VariableState(\n",
" type=BatchStat,\n",
" value=True\n",
")}\n"
]
}
],
Expand Down Expand Up @@ -203,18 +199,18 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"id": "7e065fa9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"is_param = OfType(<class 'flax.nnx.nnx.variables.Param'>)\n",
"is_param = OfType(<class 'flax.nnx.variablelib.Param'>)\n",
"everything = Everything()\n",
"nothing = Nothing()\n",
"params_or_dropout = Any(OfType(<class 'flax.nnx.nnx.variables.Param'>), WithTag('dropout'))\n"
"params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))\n"
]
}
],
Expand Down Expand Up @@ -248,26 +244,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "068208fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
" )\n",
"})\n",
"batch_stats = State({\n",
" 'b': VariableState(\n",
" type=BatchStat,\n",
" value=True\n",
" )\n",
"})\n"
"params = {'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
")}\n",
"batch_stats = {'b': VariableState(\n",
" type=BatchStat,\n",
" value=True\n",
")}\n"
]
}
],
Expand All @@ -280,7 +272,7 @@
" predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n",
" flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n",
"\n",
" for path, value in state.flat_state():\n",
" for path, value in nnx.to_flat_state(state):\n",
" for i, predicate in enumerate(predicates):\n",
" if predicate(path, value):\n",
" flat_states[i][path] = value\n",
Expand All @@ -289,7 +281,7 @@
" raise ValueError(f'No filter matched {path = } {value = }')\n",
"\n",
" states: tuple[nnx.GraphState, ...] = tuple(\n",
" nnx.State.from_flat_path(flat_state) for flat_state in flat_states\n",
" nnx.from_flat_state(flat_state) for flat_state in flat_states\n",
" )\n",
" return graphdef, *states\n",
"\n",
Expand Down Expand Up @@ -317,25 +309,22 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"id": "014da4d4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
" ),\n",
" 'b': VariableState(\n",
" type=SpecialParam,\n",
" value=0\n",
" )\n",
"})\n",
"special_params = State({})\n"
"params = {'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
"), 'b': VariableState(\n",
" type=SpecialParam,\n",
" value=0\n",
")}\n",
"special_params = {}\n"
]
}
],
Expand Down Expand Up @@ -365,26 +354,22 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"id": "a2ebf5b2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"params = State({\n",
" 'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
" )\n",
"})\n",
"special_params = State({\n",
" 'b': VariableState(\n",
" type=SpecialParam,\n",
" value=0\n",
" )\n",
"})\n"
"params = {'a': VariableState(\n",
" type=Param,\n",
" value=0\n",
")}\n",
"special_params = {'b': VariableState(\n",
" type=SpecialParam,\n",
" value=0\n",
")}\n"
]
}
],
Expand All @@ -409,7 +394,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/filters_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def split(node, *filters):
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
-for path, value in state.flat_state():
+for path, value in nnx.to_flat_state(state):
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
Expand All @@ -154,7 +154,7 @@ def split(node, *filters):
raise ValueError(f'No filter matched {path = } {value = }')
states: tuple[nnx.GraphState, ...] = tuple(
-nnx.State.from_flat_path(flat_state) for flat_state in flat_states
+nnx.from_flat_state(flat_state) for flat_state in flat_states
)
return graphdef, *states
Expand Down
6 changes: 3 additions & 3 deletions docs_nnx/guides/flax_gspmd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -501,8 +501,8 @@
")\n",
"loaded_sharded = checkpointer.restore(path / 'checkpoint_name',\n",
" target=abs_state)\n",
"jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)\n",
"jax.debug.visualize_array_sharding(loaded_sharded.w2.value)"
"jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)\n",
"jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs_nnx/guides/flax_gspmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ abs_state = jax.tree.map(
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
target=abs_state)
-jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
-jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
+jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)
+jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)
```

## Compile the training loop
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/haiku_to_flax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ The dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
-nnx.update(model, nnx.GraphState.merge(params, rest))
+nnx.update(model, nnx.merge_state(params, rest))

.. testcode:: Haiku
:hide:
Expand Down Expand Up @@ -378,7 +378,7 @@ The parameter structure is as follows:
_, params, _ = nnx.split(model, nnx.Param, ...)
params
-State({
+{
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
Expand All @@ -387,7 +387,7 @@ The parameter structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
-})
+}
To call those custom methods:
Expand Down Expand Up @@ -634,14 +634,14 @@ Now inspect the variable pytree on both sides:
_, params, _ = nnx.split(model, nnx.Param, ...)
params
-State({
+{
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
-})
+}
Top-level Haiku functions vs top-level Flax modules
Expand Down
10 changes: 5 additions & 5 deletions docs_nnx/guides/linen_to_nnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
-nnx.update(model, nnx.GraphState.merge(params, rest))
+nnx.update(model, nnx.merge_state(params, rest))
.. testcode:: Linen
:hide:
Expand Down Expand Up @@ -389,7 +389,7 @@ The variable structure is as follows:
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
-State({
+{
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
Expand All @@ -398,7 +398,7 @@ The variable structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
-})
+}
To call methods other than ``__call__``:

Expand Down Expand Up @@ -644,14 +644,14 @@ Now inspect the variable pytree on both sides:
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
-State({
+{
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
-})
+}
Using ``TrainState`` in Flax NNX
Expand Down
Loading

0 comments on commit 223bb05

Please sign in to comment.