Skip to content

Commit

Permalink
Merge pull request #1153 from google-deepmind:expose-named_chain
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704475102
  • Loading branch information
OptaxDev committed Dec 10, 2024
2 parents 3d8c391 + 443eef3 commit f076ae1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/api/combining_optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ Combining Optimizers

.. autosummary::
chain
named_chain
multi_transform

Chain
~~~~~
.. autofunction:: chain
.. autofunction:: named_chain

Multi-transform
~~~~~~~~~~~~~~~
Expand Down
51 changes: 34 additions & 17 deletions optax/transforms/_combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def chain(
updates in the given order.
Args:
*args: a sequence of chainable (init_fn, update_fn) tuples.
*args: an arbitrary number of ``transform``-s of
:class:`GradientTransformation` or
:class:`GradientTransformationExtraArgs`.
Returns:
A :class:`GradientTransformationExtraArgs`, created by chaining the input
Expand All @@ -55,6 +57,18 @@ def chain(
>>> state = chained_transform.init(params)
>>> updates = {'a': -0.5}
>>> updates, new_state = chained_transform.update(updates, state, params)
An optimizer in the chain might require extra args:
>>> import optax
>>> opt1 = optax.scale(0.1) # scale incoming gradients
>>> opt2 = optax.polyak_sgd() # requires a `value` extra arg for `update`
>>> chained_transform = optax.chain(opt1, opt2)
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
... 0.7, state, 0.7, **extra_args # extra args for all transforms
... )
"""

transforms = [base.with_extra_args_support(t) for t in args]
Expand Down Expand Up @@ -85,13 +99,13 @@ def update_fn(updates, state, params=None, **extra_args):


def named_chain(
*transforms: tuple[str, base.GradientTransformation]
*args: tuple[str, base.GradientTransformation]
) -> base.GradientTransformationExtraArgs:
"""Chains optax gradient transformations.
"""Applies a list of named chainable update transformations.
A variant of :func:`optax.chain` that allows to name each transformation.
Here the ``transforms`` are ``(name, transformation)`` pairs, constituted of a
Here the ``args`` are ``(name, transformation)`` pairs, constituted of a
string ``name`` and an associated transformation ``transformation``. The
gradient transformation must be an instance of :class:`GradientTransformation`
or :class:`GradientTransformationExtraArgs`.
Expand All @@ -101,34 +115,37 @@ def named_chain(
with a given ``name`` can be easily retrieved as ``opt_state[name]``.
Args:
*transforms: an arbitrary number of ``(name, tx)`` pairs, constituted of a
string ``name`` and an associated transformation ``tx``. The latter is a
:class:`GradientTransformation` or
*args: an arbitrary number of ``(name, transform)`` pairs, constituted of a
string ``name`` and an associated transformation ``transform``. The latter
is a :class:`GradientTransformation` or
:class:`GradientTransformationExtraArgs`.
Returns:
A single (init_fn, update_fn) tuple.
Examples:
>>> # tx1 is a GradientTransformation with no extra_args.
>>> # tx2 is a GradientTransformationExtraArgs that requires `loss`.
>>> # tx3 is a GradientTransformationExtraArgs that requires `temperature`.
>>> tx = named_chain(('one', tx1), ('two', tx2), ('three', tx3))
>>> extra_args={'loss': 0.3, 'temperature': 0.01}
>>> tx.init(params)
>>> tx.update(grads, state, params, **extra_args)
>>> import optax
>>> opt1 = optax.scale(0.1) # scale incoming gradients
>>> opt2 = optax.polyak_sgd() # requires a `value` extra arg for `update`
>>> chained_transform = optax.named_chain(("scale", opt1), ("sgd", opt2))
>>> state = chained_transform.init(0.5)
>>> extra_args = {"value": 1.0}
>>> updates, new_state = chained_transform.update(
... 0.7, state, 0.7, **extra_args # extra args for all transforms
... )
>>> tuple(new_state.keys()) == ("scale", "sgd")
True
"""

names = [name for name, _ in transforms]
names = [name for name, _ in args]

if len(names) != len(set(names)):
raise ValueError(
f'Named transformations must have unique names, but got {names}'
)

transforms = [
(name, base.with_extra_args_support(t)) for name, t in transforms
(name, base.with_extra_args_support(t)) for name, t in args
]

def init_fn(params):
Expand Down

0 comments on commit f076ae1

Please sign in to comment.