diff --git a/flax/linen/module.py b/flax/linen/module.py index e46f1db218..71dc192368 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -482,6 +482,21 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]: return method_or_fn +def _tree_has_modules(x): + if isinstance(x, Module): + return True + elif isinstance(x, (int, float, bool, str, np.ndarray, jnp.ndarray)): + return False + else: + return any(isinstance(v, Module) for v in jax.tree_util.tree_leaves(x)) + +def _map_submodules(fn: Callable[['Module'], Any], tree): + """Map a function over all submodules in a tree.""" + def _map_fn(x): + if isinstance(x, Module): + return fn(x) + return x + return jax.tree_map(_map_fn, tree) class SetupState(enum.IntEnum): # setup() has not been called. @@ -897,6 +912,15 @@ def __getattr__(self, name: str) -> Any: 'are only accessible from inside \'init\' or \'apply\'.') raise AttributeError(msg) + def __getattribute__(self, name): + """Call setup() before accessing any submodule attributes.""" + # NB: all code here is very "hot" and will be run very frequently. 
+ if ('_submodule_dataclass_fields' in object.__getattribute__(self, '__dict__') + and name in object.__getattribute__(self, '_submodule_dataclass_fields')): + object.__getattribute__(self, '_try_setup')() + # always run original python __getattribute__ + return object.__getattribute__(self, name) + def __dir__(self) -> List[str]: """Call setup() before listing attributes.""" self._try_setup() @@ -917,6 +941,15 @@ def __post_init__(self) -> None: if self.parent is _unspecified_parent: # pytype: disable=attribute-error object.__setattr__(self, 'parent', _context.module_stack[-1]) + # find all dataclass fields that have submodules + _submodule_dataclass_fields = tuple( + field.name for field in dataclasses.fields(self) + if field.name not in ('parent', 'name') + if field.name in self.__dict__ # ignore fields that have not been set + if _tree_has_modules(getattr(self, field.name)) + ) + object.__setattr__(self, '_submodule_dataclass_fields', _submodule_dataclass_fields) + # Initialization is deferred for top level Modules or any other "orphan" # Modules until attachment by __setattr__ i.e. MyModule(..., parent=None) if self.parent is None: @@ -1047,7 +1080,7 @@ def _try_setup(self, shallow: bool = False) -> None: # not call the user's setup. This avoids running before a # transformation. for field in dataclasses.fields(self): - if field.name != 'parent' and field.init: + if field.name not in ('parent', 'name') and field.init: self._register_submodules(field.name, getattr(self, field.name)) if not shallow: self.setup() @@ -1095,19 +1128,76 @@ def _initialization_allowed(self): def clone(self: M, *, parent: Optional[Union[Scope, 'Module']] = None, + _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, **updates) -> M: """Creates a clone of this Module, with optionally updated arguments. Args: parent: The parent of the clone. The clone will have no parent if no explicit parent is specified. 
+ _deep_clone: A boolean or a weak value dictionary to control deep cloning + of submodules. If True, submodules will be cloned recursively. If a + weak value dictionary is passed, it will be used to cache cloned + submodules. This flag is used by init/apply/bind to avoid scope + leakage. **updates: Attribute updates. Returns: A clone of the this Module with the updated attributes and parent. """ attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init} + attrs.update(parent=parent, **updates) - return self.__class__(**attrs) + + # Here we implement deep cloning of submodules. This is necessary to avoid scope leakage + # from external submodules into init/apply/bind while preserving sharing-by-reference + # relationships between submodules. + if _deep_clone != False: + # We use a weak value dictionary to cache cloned submodules. When a shared + # submodule is cloned, it is only cloned once; otherwise it is fetched from the cache. + cache = weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone + def clone_fn(m: Module) -> Module: + key = m._id + if key in cache: + return cache[key] + else: + clone = m.clone(_deep_clone=cache) + cache[key] = clone + return clone + + # _map_submodules will map over all submodules inside attrs + # value here can be any pytree, non-module values are ignored + for field_name, value in attrs.items(): + attrs[field_name] = _map_submodules(clone_fn, value) + + module = self.__class__(**attrs) + + # We need to register submodules recursively after cloning to ensure that + # scopes for input submodules are eagerly created; this solves problems + # that arise when sharing external submodules. + if _deep_clone is True: + module._recursive_register_submodules() + + return module + + def _recursive_register_submodules(self): + """Recursively registers submodules in this module and its children.""" + + # We only register submodules that are passed from the outside. 
+ # These are found in the _submodule_dataclass_fields attribute. + for field_name in self._submodule_dataclass_fields: + value = self.__dict__[field_name] + current_in_setup = self._state.in_setup + try: + # We are temporarily setting in_setup to True to trick + # _register_submodules into thinking it is inside setup, else it will + # error. Maybe we can create a new state for this? + self._state.in_setup = True + self._register_submodules(field_name, value) + finally: + self._state.in_setup = current_in_setup + + value = self.__dict__[field_name] + _map_submodules(lambda m: m._recursive_register_submodules(), value) def variable(self, col: str, name: str, init_fn: Optional[Callable[..., Any]] = None, @@ -1323,7 +1413,7 @@ def __call__(self, x): del args scope = core.bind(variables, rngs=rngs, mutable=mutable) - return self.clone(parent=scope) + return self.clone(parent=scope, _deep_clone=True) def unbind(self: M) -> Tuple[M, VariableDict]: """Returns an unbound copy of a Module and its variables. 
@@ -2009,7 +2099,7 @@ def f(foo, x): def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) try: - return fn(module.clone(parent=scope), *args, **kwargs) + return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) finally: _context.capture_stack.pop() @@ -2070,7 +2160,7 @@ def f(foo, x): def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) try: - return fn(module.clone(parent=scope), *args, **kwargs) + return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) finally: _context.capture_stack.pop() diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index d3056d985b..695c974619 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -21,10 +21,11 @@ import gc import inspect import operator -from typing import (Any, Callable, Generic, Mapping, NamedTuple, Sequence, +from typing import (Any, Callable, Generic, List, Mapping, NamedTuple, Sequence, Tuple, TypeVar, get_type_hints) from absl.testing import absltest +import pytest from flax import config from flax import errors from flax import linen as nn @@ -1994,6 +1995,194 @@ def __call__(self): 'Trying to access a property that'): foo.apply({}) + def test_basic_setup_test(self): + class Foo(nn.Module): + def setup(self): + self.b = 10 + + def __call__(self, x): + return self.b + x + + module = Foo() + variables = module.bind({}) + + def test_nested_external_modules(self): + class Baz(nn.Module): + a: int + + def setup(self): + self.b = self.param('b', lambda k: 2) + + def __call__(self, x): + return x + self.a * self.b + + class Bar(nn.Module): + baz: Baz + + def __call__(self, x): + return self.baz(x) + + class Foo(nn.Module): + def setup(self): + self.bar = Bar(baz=Baz(a=1)) + + def __call__(self, x): + return self.bar.baz(x) + + module = Foo() + y, variables = module.init_with_output(jax.random.PRNGKey(0), 1) + self.assertEqual(y, 3) + + def 
test_getattribute_triggers_setup(self): + class B(nn.Module): + def setup(self): + self.p1 = self.param('p1', lambda k: jnp.ones((2,))) + def fn1(self, x): + return self.p1 + x + class A(nn.Module): + b: nn.Module + def __call__(self, x): + return self.b.fn1(x) + a = A(b=B()) + k = random.PRNGKey(0) + x = jnp.zeros((2,)) + vs = nn.init(lambda a,x: a(x), a)(k, x) + y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x) + np.testing.assert_array_equal(y, jnp.ones((2,))) + + def test_nested_sequential_in_call(self): + class Foo(nn.Module): + def setup(self): + self.seq = nn.Sequential([nn.Dense(10) for i in range(10)]) + + def __call__(self, x): + # try calling only the first layer + return self.seq.layers[0](x) + + + module = Foo() + variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10))) + + def test_setup_called_bounded_submodules(self): + module = nn.Sequential([ + nn.Sequential([ + nn.Dense(2), + nn.relu, + nn.Dense(2), + ]), + nn.relu, + nn.Dense(2), + ]) + x = jnp.ones((1, 3)) + variables = module.init(jax.random.PRNGKey(0), x) + bound_module = module.bind(variables) + + self.assertIsNotNone(bound_module.layers[0].layers[0].scope) + self.assertIsNotNone(bound_module.layers[0].layers[2].scope) + self.assertIsNotNone(bound_module.layers[2].scope) + + def test_call_bounded_toplevel_mutable(self): + class Bar(nn.Module): + a: int + + def setup(self): + self.b = self.param('b', lambda k: 1) + + def __call__(self, x): + return x + self.a * self.b + + class Foo(nn.Module): + bars: Sequence[Bar] + + def __call__(self, x): + for bar in self.bars: + x = bar(x) + return x + + + module = Foo(bars=[]) + module.bars = [Bar(a=1)] + + variables = module.init(jax.random.PRNGKey(0), jnp.ones(())) + bound_module = module.bind(variables) + + bar1 = bound_module.bars[0] + self.assertIsNotNone(bar1.scope) + + # @pytest.mark.skip(reason="Leaving it here to tackle it later") + def test_nested_shared(self): + class Shared(nn.Module): + @nn.compact + def __call__(self, x): + return 
nn.Dense(1)(x) + + class Unshared(nn.Module): + shared: nn.Module + def __call__(self, x): + return self.shared(x) + + class Super(nn.Module): + a: nn.Module + b: nn.Module + def run_a(self, x): + return self.a(x) + def run_b(self, x): + return self.b(x) + def __call__(self, x): + return self.a(x) + self.b(x) + + + sh = Shared() + a = Unshared(shared=sh) + b = Unshared(shared=sh) + module = Super(a=a, b=b) + + rng = jax.random.PRNGKey(0) + params = module.init(rng, jnp.ones(1))["params"] + + module.apply({"params": params}, jnp.ones(1)) # works as expected + module.apply({"params": params}, jnp.ones(1), method="run_a") # works as expected + module.apply({"params": params}, jnp.ones(1), method="run_b") # regression check: this call previously raised ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0" + + def test_nested_init(self): + class Baz(nn.Module): + a: int + + def setup(self): + self.b = self.param('b', lambda k: jnp.ones(())) + + def __call__(self, x): + return x + self.a * self.b + + class Bar(nn.Module): + baz: Baz + + def setup(self): + a = 1 + + def __call__(self, x): + return self.baz(x) + + class Foo(nn.Module): + + def setup(self): + self.bar: Bar = Bar(baz=Baz(a=1)) + + def __call__(self, x): + # y = self.bar(x) + y, bar_vars = self.bar.init_with_output(jax.random.PRNGKey(0), x) + return y, bar_vars + + # create foo + module = Foo() + + # run foo + (y, bar_vars), variables = module.init_with_output( + jax.random.PRNGKey(0), jnp.ones(())) + + self.assertIn('params', bar_vars) + + def test_repr(self): class Base1(nn.Module):