diff --git a/optax/_src/combine.py b/optax/_src/combine.py index 4c7585be6..b30498bb7 100644 --- a/optax/_src/combine.py +++ b/optax/_src/combine.py @@ -147,7 +147,7 @@ def multi_transform( import jax.numpy as jnp def map_nested_fn(fn): - '''Recursively apply `fn` to the key-value pairs of a nested dict''' + '''Recursively apply `fn` to the key-value pairs of a nested dict.''' def map_fn(nested_dict): return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) for k, v in nested_dict.items()} @@ -178,7 +178,7 @@ def map_fn(nested_dict): param_labels) If you would like to not optimize some parameters, you may wrap - ``optax.multi_transform`` with :func:`optax.masked`. + :func:`optax.multi_transform` with :func:`optax.masked`. Args: transforms: A mapping from labels to transformations. Each transformation @@ -191,7 +191,8 @@ def map_fn(nested_dict): extra_arg fields with the same tree structure as params/updates. Returns: - An ``optax.GradientTransformation``. + A :func:`optax.GradientTransformationExtraArgs` that implements an ``init`` + and ``update`` function. """ transforms = {