Skip to content

Commit

Permalink
improve submodule binding logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Feb 22, 2023
1 parent 26d46e4 commit 1d67d4e
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 6 deletions.
100 changes: 95 additions & 5 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,21 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]:

return method_or_fn

def _tree_has_modules(x):
if isinstance(x, Module):
return True
elif isinstance(x, (int, float, bool, str, np.ndarray, jnp.ndarray)):
return False
else:
return any(isinstance(v, Module) for v in jax.tree_util.tree_leaves(x))

def _map_submodules(fn: Callable[['Module'], Any], tree):
"""Map a function over all submodules in a tree."""
def _map_fn(x):
if isinstance(x, Module):
return fn(x)
return x
return jax.tree_map(_map_fn, tree)

class SetupState(enum.IntEnum):
# setup() has not been called.
Expand Down Expand Up @@ -897,6 +912,15 @@ def __getattr__(self, name: str) -> Any:
'are only accessible from inside \'init\' or \'apply\'.')
raise AttributeError(msg)

def __getattribute__(self, name):
"""Call setup() before accessing any submodule attributes."""
# NB: all code here is very "hot" and will be run very frequently.
if ('_submodule_dataclass_fields' in object.__getattribute__(self, '__dict__')
and name in object.__getattribute__(self, '_submodule_dataclass_fields')):
object.__getattribute__(self, '_try_setup')()
# always run original python __getattribute__
return object.__getattribute__(self, name)

def __dir__(self) -> List[str]:
"""Call setup() before listing attributes."""
self._try_setup()
Expand All @@ -917,6 +941,15 @@ def __post_init__(self) -> None:
if self.parent is _unspecified_parent: # pytype: disable=attribute-error
object.__setattr__(self, 'parent', _context.module_stack[-1])

# find all dataclass fields that have submodules
_submodule_dataclass_fields = tuple(
field.name for field in dataclasses.fields(self)
if field.name not in ('parent', 'name')
if field.name in self.__dict__ # ignore fields that have not been set
if _tree_has_modules(getattr(self, field.name))
)
object.__setattr__(self, '_submodule_dataclass_fields', _submodule_dataclass_fields)

# Initialization is deferred for top level Modules or any other "orphan"
# Modules until attachment by __setattr__ i.e. MyModule(..., parent=None)
if self.parent is None:
Expand Down Expand Up @@ -1047,7 +1080,7 @@ def _try_setup(self, shallow: bool = False) -> None:
# not call the user's setup. This avoids running before a
# transformation.
for field in dataclasses.fields(self):
if field.name != 'parent' and field.init:
if field.name not in ('parent', 'name') and field.init:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
Expand Down Expand Up @@ -1095,19 +1128,76 @@ def _initialization_allowed(self):

def clone(self: M, *,
parent: Optional[Union[Scope, 'Module']] = None,
_deep_clone: Union[bool, weakref.WeakValueDictionary] = False,
**updates) -> M:
"""Creates a clone of this Module, with optionally updated arguments.
Args:
parent: The parent of the clone. The clone will have no parent if no
explicit parent is specified.
_deep_clone: A boolean or a weak value dictionary to control deep cloning
of submodules. If True, submodules will be cloned recursively. If a
weak value dictionary is passed, it will be used to cache cloned
submodules. This flag is used by init/apply/bind to avoid scope
leakage.
**updates: Attribute updates.
Returns:
A clone of the this Module with the updated attributes and parent.
"""
attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init}

attrs.update(parent=parent, **updates)
return self.__class__(**attrs)

# Here we implement deep cloning of submodules, this is necessary to avoid scope leakage
# from external submodules into init/apply/bind while preserving sharing-by-reference
# relationships between submodules.
if _deep_clone != False:
# We use a weak value dictionary to cache cloned submodules. When a shared
# submodule is cloned, its only cloned once else its fetched from the cache.
cache = weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone
def clone_fn(m: Module) -> Module:
key = m._id
if key in cache:
return cache[key]
else:
clone = m.clone(_deep_clone=cache)
cache[key] = clone
return clone

# _map_submodules will map over all submodules inside attrs
# value here can be any pytree, non-module values are ignored
for field_name, value in attrs.items():
attrs[field_name] = _map_submodules(clone_fn, value)

module = self.__class__(**attrs)

# We need to register submodules recursively after cloning to ensure that
# scopes for input submodules are eagerly created, this solves problems
# that arise when sharing external submodules.
if _deep_clone is True:
module._recursive_register_submodules()

return module

def _recursive_register_submodules(self):
"""Recursively registers submodules in this module and its children."""

# We only register submodules that are passed from the outside.
# These are found in the _submodule_dataclass_fields attribute.
for field_name in self._submodule_dataclass_fields:
value = self.__dict__[field_name]
current_in_setup = self._state.in_setup
try:
# We are temporarily setting in_setup to True to trick
# _register_submodules thinking its inside setup, else it will
# error. Maybe we can create a new state for this?
self._state.in_setup = True
self._register_submodules(field_name, value)
finally:
self._state.in_setup = current_in_setup

value = self.__dict__[field_name]
_map_submodules(lambda m: m._recursive_register_submodules(), value)

def variable(self, col: str, name: str,
init_fn: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -1323,7 +1413,7 @@ def __call__(self, x):

del args
scope = core.bind(variables, rngs=rngs, mutable=mutable)
return self.clone(parent=scope)
return self.clone(parent=scope, _deep_clone=True)

def unbind(self: M) -> Tuple[M, VariableDict]:
"""Returns an unbound copy of a Module and its variables.
Expand Down Expand Up @@ -2009,7 +2099,7 @@ def f(foo, x):
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
finally:
_context.capture_stack.pop()

Expand Down Expand Up @@ -2070,7 +2160,7 @@ def f(foo, x):
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
finally:
_context.capture_stack.pop()

Expand Down
191 changes: 190 additions & 1 deletion tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
import gc
import inspect
import operator
from typing import (Any, Callable, Generic, Mapping, NamedTuple, Sequence,
from typing import (Any, Callable, Generic, List, Mapping, NamedTuple, Sequence,
Tuple, TypeVar, get_type_hints)

from absl.testing import absltest
import pytest
from flax import config
from flax import errors
from flax import linen as nn
Expand Down Expand Up @@ -1994,6 +1995,194 @@ def __call__(self):
'Trying to access a property that'):
foo.apply({})

def test_basic_setup_test(self):
class Foo(nn.Module):
def setup(self):
self.b = 10

def __call__(self, x):
return self.b + x

module = Foo()
variables = module.bind({})

def test_nested_external_modules(self):
class Baz(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: 2)

def __call__(self, x):
return x + self.a * self.b

class Bar(nn.Module):
baz: Baz

def __call__(self, x):
return self.baz(x)

class Foo(nn.Module):
def setup(self):
self.bar = Bar(baz=Baz(a=1))

def __call__(self, x):
return self.bar.baz(x)

module = Foo()
y, variables = module.init_with_output(jax.random.PRNGKey(0), 1)
self.assertEqual(y, 3)

def test_getattribute_triggers_setup(self):
class B(nn.Module):
def setup(self):
self.p1 = self.param('p1', lambda k: jnp.ones((2,)))
def fn1(self, x):
return self.p1 + x
class A(nn.Module):
b: nn.Module
def __call__(self, x):
return self.b.fn1(x)
a = A(b=B())
k = random.PRNGKey(0)
x = jnp.zeros((2,))
vs = nn.init(lambda a,x: a(x), a)(k, x)
y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x)
np.testing.assert_array_equal(y, jnp.ones((2,)))

def test_nested_sequential_in_call(self):
class Foo(nn.Module):
def setup(self):
self.seq = nn.Sequential([nn.Dense(10) for i in range(10)])

def __call__(self, x):
# try calling only the first layer
return self.seq.layers[0](x)


module = Foo()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))

def test_setup_called_bounded_submodules(self):
module = nn.Sequential([
nn.Sequential([
nn.Dense(2),
nn.relu,
nn.Dense(2),
]),
nn.relu,
nn.Dense(2),
])
x = jnp.ones((1, 3))
variables = module.init(jax.random.PRNGKey(0), x)
bound_module = module.bind(variables)

self.assertIsNotNone(bound_module.layers[0].layers[0].scope)
self.assertIsNotNone(bound_module.layers[0].layers[2].scope)
self.assertIsNotNone(bound_module.layers[2].scope)

def test_call_bounded_toplevel_mutable(self):
class Bar(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: 1)

def __call__(self, x):
return x + self.a * self.b

class Foo(nn.Module):
bars: Sequence[Bar]

def __call__(self, x):
for bar in self.bars:
x = bar(x)
return x


module = Foo(bars=[])
module.bars = [Bar(a=1)]

variables = module.init(jax.random.PRNGKey(0), jnp.ones(()))
bound_module = module.bind(variables)

bar1 = bound_module.bars[0]
self.assertIsNotNone(bar1.scope)

# @pytest.mark.skip(reason="Leaving it here to tackle it later")
def test_nested_shared(self):
class Shared(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Dense(1)(x)

class Unshared(nn.Module):
shared: nn.Module
def __call__(self, x):
return self.shared(x)

class Super(nn.Module):
a: nn.Module
b: nn.Module
def run_a(self, x):
return self.a(x)
def run_b(self, x):
return self.b(x)
def __call__(self, x):
return self.a(x) + self.b(x)


sh = Shared()
a = Unshared(shared=sh)
b = Unshared(shared=sh)
module = Super(a=a, b=b)

rng = jax.random.PRNGKey(0)
params = module.init(rng, jnp.ones(1))["params"]

module.apply({"params": params}, jnp.ones(1)) # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_a") # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_b") # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0"

def test_nested_init(self):
class Baz(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: jnp.ones(()))

def __call__(self, x):
return x + self.a * self.b

class Bar(nn.Module):
baz: Baz

def setup(self):
a = 1

def __call__(self, x):
return self.baz(x)

class Foo(nn.Module):

def setup(self):
self.bar: Bar = Bar(baz=Baz(a=1))

def __call__(self, x):
# y = self.bar(x)
y, bar_vars = self.bar.init_with_output(jax.random.PRNGKey(0), x)
return y, bar_vars

# create foo
module = Foo()

# run foo
(y, bar_vars), variables = module.init_with_output(
jax.random.PRNGKey(0), jnp.ones(()))

self.assertIn('params', bar_vars)


def test_repr(self):

class Base1(nn.Module):
Expand Down

0 comments on commit 1d67d4e

Please sign in to comment.