Skip to content

Commit

Permalink
Fix masking incompatibility with Equinox
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653116911
  • Loading branch information
vroulet authored and OptaxDev committed Jul 17, 2024
1 parent ed6062d commit 7d315cf
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion optax/transforms/_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ class MaskedNode(NamedTuple):
"""


def mask_callable(
mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]]
):
callable_leaves = jtu.tree_leaves(jtu.tree_map(callable, mask))
return (len(callable_leaves) > 0) and all(callable_leaves) # pylint:disable=g-explicit-length-test


def masked(
inner: base.GradientTransformation,
mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]],
Expand Down Expand Up @@ -120,7 +127,7 @@ def init_fn(params):
return MaskedState(inner_state=inner.init(masked_params))

def update_fn(updates, state, params=None, **extra_args):
mask_tree = mask(updates) if callable(mask) else mask
mask_tree = mask(updates) if mask_callable(mask) else mask
masked_extra_args = maybe_mask_values(extra_args, updates, mask_tree)
masked_updates = mask_pytree(updates, mask_tree)
masked_params = None if params is None else mask_pytree(params, mask_tree)
Expand Down

0 comments on commit 7d315cf

Please sign in to comment.