From 8f122f602b1fb2f310eac67775d2a1566d8c0744 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 25 May 2023 22:19:21 +0000 Subject: [PATCH] eagerly bind submodules --- flax/linen/module.py | 64 ++++++++++-- pyproject.toml | 6 +- tests/linen/linen_module_test.py | 174 +++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 9 deletions(-) diff --git a/flax/linen/module.py b/flax/linen/module.py index 6289a0cf9e..68e78f38ea 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -490,6 +490,10 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]: return method_or_fn +def _map_submodules(fn: Callable[['Module'], Any], tree): + """Map a function over all submodules in a tree.""" + g = lambda _, x: fn(x) if isinstance(x, Module) else x + return _freeze_attr(_map_over_modules_in_tree(g, tree)) class SetupState(enum.IntEnum): # setup() has not been called. @@ -1000,6 +1004,12 @@ def __post_init__(self) -> None: else: raise ValueError('parent must be None, Module or Scope') + # eagerly bind submodules if scope is available + if self.scope is not None: + for field in dataclasses.fields(self): + if field.name not in ('parent', 'name') and field.init: + self._register_submodules(field.name, getattr(self, field.name)) + self._state.is_initialized = True def __repr__(self) -> str: @@ -1048,8 +1058,8 @@ def _register_submodules(self, name, val): _caches[root] = cache queue = [] preserve_adopted_names = config.flax_preserve_adopted_names - if hasattr(self, 'preserve_adopted_names'): - preserve_adopted_names = self.preserve_adopted_names + if hasattr(type(self), 'preserve_adopted_names'): + preserve_adopted_names = type(self).preserve_adopted_names def adopt_attr_modules(cache, queue, suffix, subvalue): if isinstance(subvalue, Module): adopted_name = None @@ -1092,7 +1102,7 @@ def _try_setup(self, shallow: bool = False) -> None: # not call the user's setup. This avoids running before a # transformation. for field in dataclasses.fields(self): - if field.name != 'parent' and field.init: + if field.name not in ('parent', 'name') and field.init: self._register_submodules(field.name, getattr(self, field.name)) if not shallow: self.setup() @@ -1137,23 +1147,61 @@ def _name_taken(self, @property def _initialization_allowed(self): - return self._state.in_setup or self._state.in_compact_method + return (not self._state.is_initialized # allow eager attachment in post-init + or self._state.in_setup + or self._state.in_compact_method) def clone(self: M, *, parent: Optional[Union[Scope, 'Module']] = None, + _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, **updates) -> M: """Creates a clone of this Module, with optionally updated arguments. Args: parent: The parent of the clone. The clone will have no parent if no explicit parent is specified. + _deep_clone: A boolean or a weak value dictionary to control deep cloning + of submodules. If True, submodules will be cloned recursively. If a + weak value dictionary is passed, it will be used to cache cloned + submodules. This flag is used by init/apply/bind to avoid scope + leakage. **updates: Attribute updates. Returns: A clone of the this Module with the updated attributes and parent. """ attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init} + attrs.update(parent=parent, **updates) - return self.__class__(**attrs) + + # Here we implement deep cloning of submodules, this is necessary to avoid scope leakage + # from external submodules into init/apply/bind while preserving sharing-by-reference + # relationships between submodules. + if _deep_clone != False: + # We use a weak value dictionary to cache cloned submodules. When a shared + # submodule is cloned, its only cloned once else its fetched from the cache. + cache = weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone + def clone_fn(m: Module) -> Module: + if hasattr(m, '_id'): + key = m._id + if key in cache: + return cache[key] + else: + clone = m.clone(_deep_clone=cache) + cache[key] = clone + return clone + else: + # If the module doesn't have an _id attribute it could be a mock object + # so we return it as is. + return m + + # _map_submodules will map over all submodules inside attrs + # value here can be any pytree, non-module values are ignored + for field_name, value in attrs.items(): + attrs[field_name] = _map_submodules(clone_fn, value) + + module = self.__class__(**attrs) + + return module def variable(self, col: str, name: str, init_fn: Optional[Callable[..., Any]] = None, @@ -1371,7 +1419,7 @@ def __call__(self, x): del args scope = core.bind(variables, rngs=rngs, mutable=mutable) - return self.clone(parent=scope) + return self.clone(parent=scope, _deep_clone=True) def unbind(self: M) -> Tuple[M, VariableDict]: """Returns an unbound copy of a Module and its variables. @@ -2057,7 +2105,7 @@ def f(foo, x): def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) try: - return fn(module.clone(parent=scope), *args, **kwargs) + return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) finally: _context.capture_stack.pop() @@ -2118,7 +2166,7 @@ def f(foo, x): def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) try: - return fn(module.clone(parent=scope), *args, **kwargs) + return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) finally: _context.capture_stack.pop() diff --git a/pyproject.toml b/pyproject.toml index 0ccbd93393..4f6d9e0505 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,7 +131,11 @@ filterwarnings = [ "ignore:file system plugins are not loaded.*:UserWarning", "ignore:unable to load libtensorflow_io_plugins.so.*:UserWarning", # Remove this after next Optax release after 3/27/2023 - "ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning" + "ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning", + # DeprecationWarning: pkg_resources is deprecated as an API + "ignore:.*pkg_resources is deprecated.*:DeprecationWarning", + # DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. + "ignore:.*Deprecated call to `pkg_resources.declare_namespace.*:DeprecationWarning", ] [tool.coverage.report] diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index cff7946157..d7e3a4cbc1 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2000,6 +2000,180 @@ def __call__(self): 'Trying to access a property that'): foo.apply({}) + def test_nested_external_modules(self): + class Baz(nn.Module): + a: int + + def setup(self): + self.b = self.param('b', lambda k: 2) + + def __call__(self, x): + return x + self.a * self.b + + class Bar(nn.Module): + baz: Baz + + def __call__(self, x): + return self.baz(x) + + class Foo(nn.Module): + def setup(self): + self.bar = Bar(baz=Baz(a=1)) + + def __call__(self, x): + return self.bar.baz(x) + + module = Foo() + y, variables = module.init_with_output(jax.random.PRNGKey(0), 1) + self.assertEqual(y, 3) + + def test_getattribute_triggers_setup(self): + class B(nn.Module): + def setup(self): + self.p1 = self.param('p1', lambda k: jnp.ones((2,))) + def fn1(self, x): + return self.p1 + x + class A(nn.Module): + b: nn.Module + def __call__(self, x): + return self.b.fn1(x) + a = A(b=B()) + k = random.PRNGKey(0) + x = jnp.zeros((2,)) + vs = nn.init(lambda a,x: a(x), a)(k, x) + y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x) + np.testing.assert_array_equal(y, jnp.ones((2,))) + + def test_nested_sequential_in_call(self): + class Foo(nn.Module): + def setup(self): + self.seq = nn.Sequential([nn.Dense(10) for i in range(10)]) + + def __call__(self, x): + # try calling only the first layer + return self.seq.layers[0](x) + + module = Foo() + variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10))) + + def test_setup_called_bounded_submodules(self): + module = nn.Sequential([ + nn.Sequential([ + nn.Dense(2), + nn.relu, + nn.Dense(2), + ]), + nn.relu, + nn.Dense(2), + ]) + x = jnp.ones((1, 3)) + variables = module.init(jax.random.PRNGKey(0), x) + bound_module = module.bind(variables) + + self.assertIsNotNone(bound_module.layers[0].layers[0].scope) + self.assertIsNotNone(bound_module.layers[0].layers[2].scope) + self.assertIsNotNone(bound_module.layers[2].scope) + + def test_call_bounded_toplevel_mutable(self): + class Bar(nn.Module): + a: int + + def setup(self): + self.b = self.param('b', lambda k: 1) + + def __call__(self, x): + return x + self.a * self.b + + class Foo(nn.Module): + bars: Sequence[Bar] + + def __call__(self, x): + for bar in self.bars: + x = bar(x) + return x + + + module = Foo(bars=[]) + module.bars = [Bar(a=1)] + + variables = module.init(jax.random.PRNGKey(0), jnp.ones(())) + bound_module = module.bind(variables) + + bar1 = bound_module.bars[0] + self.assertIsNotNone(bar1.scope) + + def test_nested_init(self): + class Baz(nn.Module): + a: int + + def setup(self): + self.b = self.param('b', lambda k: jnp.ones(())) + + def __call__(self, x): + return x + self.a * self.b + + class Bar(nn.Module): + baz: Baz + + def setup(self): + a = 1 + + def __call__(self, x): + return self.baz(x) + + class Foo(nn.Module): + + def setup(self): + self.bar: Bar = Bar(baz=Baz(a=1)) + + def __call__(self, x): + # y = self.bar(x) + y, bar_vars = self.bar.init_with_output(jax.random.PRNGKey(0), x) + return y, bar_vars + + # create foo + module = Foo() + + # run foo + (y, bar_vars), variables = module.init_with_output( + jax.random.PRNGKey(0), jnp.ones(())) + + self.assertIn('params', bar_vars) + + def test_nested_shared(self): + class Shared(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(1)(x) + + class Unshared(nn.Module): + shared: nn.Module + def __call__(self, x): + return self.shared(x) + + class Super(nn.Module): + a: nn.Module + b: nn.Module + def run_a(self, x): + return self.a(x) + def run_b(self, x): + return self.b(x) + def __call__(self, x): + return self.a(x) + self.b(x) + + + sh = Shared() + a = Unshared(shared=sh) + b = Unshared(shared=sh) + module = Super(a=a, b=b) + + rng = jax.random.PRNGKey(0) + params = module.init(rng, jnp.ones(1))["params"] + + module.apply({"params": params}, jnp.ones(1)) # works as expected + module.apply({"params": params}, jnp.ones(1), method="run_a") # works as expected + module.apply({"params": params}, jnp.ones(1), method="run_b") # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0" + def test_repr(self): class Base1(nn.Module):