Skip to content

Commit

Permalink
Make masking compatible with callable pytrees a la Equinox
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653116911
  • Loading branch information
vroulet authored and OptaxDev committed Jul 17, 2024
1 parent f21430a commit c323822
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
9 changes: 8 additions & 1 deletion optax/transforms/_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ class MaskedNode(NamedTuple):
"""


def _mask_callable(
mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]]
):
callable_leaves = jtu.tree_leaves(jtu.tree_map(callable, mask))
return (len(callable_leaves) > 0) and all(callable_leaves) # pylint:disable=g-explicit-length-test


def masked(
inner: base.GradientTransformation,
mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]],
Expand Down Expand Up @@ -120,7 +127,7 @@ def init_fn(params):
return MaskedState(inner_state=inner.init(masked_params))

def update_fn(updates, state, params=None, **extra_args):
mask_tree = mask(updates) if callable(mask) else mask
mask_tree = mask(updates) if _mask_callable(mask) else mask
masked_extra_args = maybe_mask_values(extra_args, updates, mask_tree)
masked_updates = mask_pytree(updates, mask_tree)
masked_params = None if params is None else mask_pytree(params, mask_tree)
Expand Down
28 changes: 27 additions & 1 deletion optax/transforms/_masking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
"""Tests for optax.transforms._masking."""

import copy
import dataclasses
from typing import cast

from absl.testing import absltest
from absl.testing import parameterized

import chex
from jax import tree_util as jtu
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from optax._src import alias
Expand Down Expand Up @@ -343,6 +345,30 @@ def test_masked_state_structure(self):
}
chex.assert_trees_all_equal_structs(trace, expected_trace)

def test_mask_compatible_callable_pytree(self):
"""Test if mask is compatible with callable pytrees a la equinox."""
# https://github.com/google-deepmind/optax/issues/913

# Define a dataclass with a callable
@dataclasses.dataclass
class _MaskCompatibleModule():
w: jax.Array

def __call__(self, x):
return self.w.dot(x)

# Make it a pytree by registring it as a pytree node
jtu.register_pytree_node(
_MaskCompatibleModule,
lambda m: (m.w, None),
lambda _, c: _MaskCompatibleModule(c)
)
module = _MaskCompatibleModule(jnp.array([1., 2., 3.]))

# The module is then callable (like the Module class in equinox)
self.assertTrue(callable(module))
# But it won't be treated as a callable by the masking logic
self.assertFalse(_masking._mask_callable(module))

if __name__ == '__main__':
absltest.main()

0 comments on commit c323822

Please sign in to comment.