diff --git a/optax/_src/combine.py b/optax/_src/combine.py
index 4c7585be6..b30498bb7 100644
--- a/optax/_src/combine.py
+++ b/optax/_src/combine.py
@@ -147,7 +147,7 @@ def multi_transform(
     import jax.numpy as jnp
 
     def map_nested_fn(fn):
-      '''Recursively apply `fn` to the key-value pairs of a nested dict'''
+      '''Recursively apply `fn` to the key-value pairs of a nested dict.'''
       def map_fn(nested_dict):
         return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
                 for k, v in nested_dict.items()}
@@ -178,7 +178,7 @@ def map_fn(nested_dict):
         param_labels)
 
   If you would like to not optimize some parameters, you may wrap
-  ``optax.multi_transform`` with :func:`optax.masked`.
+  :func:`optax.multi_transform` with :func:`optax.masked`.
 
   Args:
     transforms: A mapping from labels to transformations. Each transformation
@@ -191,7 +191,8 @@ def map_fn(nested_dict):
       extra_arg fields with the same tree structure as params/updates.
 
   Returns:
-    An ``optax.GradientTransformation``.
+    A :func:`optax.GradientTransformationExtraArgs` that implements an ``init``
+    and ``update`` function.
   """
 
   transforms = {