Skip to content

Commit

Permalink
Extend capabilities of tree_get, tree_set.
Browse files Browse the repository at this point in the history
1. Enable tree_get, tree_set to filter for the name of a named tuple in the path to a key (hence filter the name of a state in a chained transformation).
This enables distinguishing for attributes identical in two different states except that the names of the states are different.

2. Enable tree_get, tree_set to fetch or set named tuples by the name of the name tuple. This is handy to fetch a given state in the overall state of a chained optimizer.

PiperOrigin-RevId: 617976030
  • Loading branch information
vroulet authored and OptaxDev committed Mar 22, 2024
1 parent 207983d commit 60e8710
Show file tree
Hide file tree
Showing 4 changed files with 540 additions and 257 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Tree
.. currentmodule:: optax.tree_utils

.. autosummary::
NamedTupleKey
tree_add
tree_add_scalar_mul
tree_div
Expand All @@ -112,6 +113,10 @@ Tree
tree_vdot
tree_zeros_like

NamedTupleKey
~~~~~~~~~~~~~
.. autoclass:: tree_add

Tree add
~~~~~~~~
.. autofunction:: tree_add
Expand Down
2 changes: 2 additions & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# ==============================================================================
"""The tree_utils sub-package."""

from optax.tree_utils._state_utils import NamedTupleKey
from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
from optax.tree_utils._state_utils import tree_map_params
from optax.tree_utils._state_utils import tree_set


from optax.tree_utils._tree_math import tree_add
from optax.tree_utils._tree_math import tree_add_scalar_mul
from optax.tree_utils._tree_math import tree_div
Expand Down
Loading

0 comments on commit 60e8710

Please sign in to comment.