Skip to content

Commit

Permalink
Merge pull request #879 from google-deepmind:fabianp-patch-4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618179011
  • Loading branch information
OptaxDev committed Mar 22, 2024
2 parents 207983d + 1dc1c4a commit 0927c15
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions optax/_src/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def multi_transform(
import jax.numpy as jnp
def map_nested_fn(fn):
'''Recursively apply `fn` to the key-value pairs of a nested dict'''
'''Recursively apply `fn` to the key-value pairs of a nested dict.'''
def map_fn(nested_dict):
return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
for k, v in nested_dict.items()}
Expand Down Expand Up @@ -178,7 +178,7 @@ def map_fn(nested_dict):
param_labels)
If you would like to not optimize some parameters, you may wrap
``optax.multi_transform`` with :func:`optax.masked`.
:func:`optax.multi_transform` with :func:`optax.masked`.
Args:
transforms: A mapping from labels to transformations. Each transformation
Expand All @@ -191,7 +191,8 @@ def map_fn(nested_dict):
extra_arg fields with the same tree structure as params/updates.
Returns:
An ``optax.GradientTransformation``.
A :func:`optax.GradientTransformationExtraArgs` that implements an ``init``
and ``update`` function.
"""

transforms = {
Expand Down

0 comments on commit 0927c15

Please sign in to comment.