Skip to content

Commit

Permalink
eagerly bind submodules
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed May 19, 2023
1 parent 63df54a commit 0a2ae27
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 6 deletions.
64 changes: 59 additions & 5 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,13 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]:

return method_or_fn

def _tree_has_modules(x):
return any(isinstance(v, Module) for v in jax.tree_util.tree_leaves(x))

def _map_submodules(fn: Callable[['Module'], Any], tree):
"""Map a function over all submodules in a tree."""
g = lambda _, x: fn(x) if isinstance(x, Module) else x
return _freeze_attr(_map_over_modules_in_tree(g, tree))

class SetupState(enum.IntEnum):
# setup() has not been called.
Expand Down Expand Up @@ -1000,6 +1007,17 @@ def __post_init__(self) -> None:
else:
raise ValueError('parent must be None, Module or Scope')

# eagerly bind submodules if scope is available
if self.scope is not None:
for field in dataclasses.fields(self):
if (
field.name not in ('parent', 'name')
and field.name in self.__dict__ # ignore fields that have not been set
):
value = getattr(self, field.name)
if _tree_has_modules(value):
self._register_submodules(field.name, value)

self._state.is_initialized = True

def __repr__(self) -> str:
Expand Down Expand Up @@ -1092,7 +1110,7 @@ def _try_setup(self, shallow: bool = False) -> None:
# not call the user's setup. This avoids running before a
# transformation.
for field in dataclasses.fields(self):
if field.name != 'parent' and field.init:
if field.name not in ('parent', 'name') and field.init:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
Expand Down Expand Up @@ -1141,19 +1159,55 @@ def _initialization_allowed(self):

def clone(self: M, *,
parent: Optional[Union[Scope, 'Module']] = None,
_deep_clone: Union[bool, weakref.WeakValueDictionary] = False,
**updates) -> M:
"""Creates a clone of this Module, with optionally updated arguments.
Args:
parent: The parent of the clone. The clone will have no parent if no
explicit parent is specified.
_deep_clone: A boolean or a weak value dictionary to control deep cloning
of submodules. If True, submodules will be cloned recursively. If a
weak value dictionary is passed, it will be used to cache cloned
submodules. This flag is used by init/apply/bind to avoid scope
leakage.
**updates: Attribute updates.
Returns:
A clone of the this Module with the updated attributes and parent.
"""
attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init}

attrs.update(parent=parent, **updates)
return self.__class__(**attrs)

# Here we implement deep cloning of submodules, this is necessary to avoid scope leakage
# from external submodules into init/apply/bind while preserving sharing-by-reference
# relationships between submodules.
if _deep_clone != False:
# We use a weak value dictionary to cache cloned submodules. When a shared
# submodule is cloned, its only cloned once else its fetched from the cache.
cache = weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone
def clone_fn(m: Module) -> Module:
if hasattr(m, '_id'):
key = m._id
if key in cache:
return cache[key]
else:
clone = m.clone(_deep_clone=cache)
cache[key] = clone
return clone
else:
# If the module doesn't have an _id attribute it could be a mock object
# so we return it as is.
return m

# _map_submodules will map over all submodules inside attrs
# value here can be any pytree, non-module values are ignored
for field_name, value in attrs.items():
attrs[field_name] = _map_submodules(clone_fn, value)

module = self.__class__(**attrs)

return module

def variable(self, col: str, name: str,
init_fn: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -1371,7 +1425,7 @@ def __call__(self, x):

del args
scope = core.bind(variables, rngs=rngs, mutable=mutable)
return self.clone(parent=scope)
return self.clone(parent=scope, _deep_clone=True)

def unbind(self: M) -> Tuple[M, VariableDict]:
"""Returns an unbound copy of a Module and its variables.
Expand Down Expand Up @@ -2057,7 +2111,7 @@ def f(foo, x):
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
finally:
_context.capture_stack.pop()

Expand Down Expand Up @@ -2118,7 +2172,7 @@ def f(foo, x):
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
finally:
_context.capture_stack.pop()

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ filterwarnings = [
"ignore:file system plugins are not loaded.*:UserWarning",
"ignore:unable to load libtensorflow_io_plugins.so.*:UserWarning",
# Remove this after next Optax release after 3/27/2023
"ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning"
"ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning",
# DeprecationWarning: pkg_resources is deprecated as an API
"ignore:.*pkg_resources is deprecated.*:DeprecationWarning",
# DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
"ignore:.*Deprecated call to `pkg_resources.declare_namespace.*:DeprecationWarning",
]

[tool.coverage.report]
Expand Down
174 changes: 174 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,180 @@ def __call__(self):
'Trying to access a property that'):
foo.apply({})

def test_nested_external_modules(self):
class Baz(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: 2)

def __call__(self, x):
return x + self.a * self.b

class Bar(nn.Module):
baz: Baz

def __call__(self, x):
return self.baz(x)

class Foo(nn.Module):
def setup(self):
self.bar = Bar(baz=Baz(a=1))

def __call__(self, x):
return self.bar.baz(x)

module = Foo()
y, variables = module.init_with_output(jax.random.PRNGKey(0), 1)
self.assertEqual(y, 3)

def test_getattribute_triggers_setup(self):
class B(nn.Module):
def setup(self):
self.p1 = self.param('p1', lambda k: jnp.ones((2,)))
def fn1(self, x):
return self.p1 + x
class A(nn.Module):
b: nn.Module
def __call__(self, x):
return self.b.fn1(x)
a = A(b=B())
k = random.PRNGKey(0)
x = jnp.zeros((2,))
vs = nn.init(lambda a,x: a(x), a)(k, x)
y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x)
np.testing.assert_array_equal(y, jnp.ones((2,)))

def test_nested_sequential_in_call(self):
class Foo(nn.Module):
def setup(self):
self.seq = nn.Sequential([nn.Dense(10) for i in range(10)])

def __call__(self, x):
# try calling only the first layer
return self.seq.layers[0](x)

module = Foo()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))

def test_setup_called_bounded_submodules(self):
module = nn.Sequential([
nn.Sequential([
nn.Dense(2),
nn.relu,
nn.Dense(2),
]),
nn.relu,
nn.Dense(2),
])
x = jnp.ones((1, 3))
variables = module.init(jax.random.PRNGKey(0), x)
bound_module = module.bind(variables)

self.assertIsNotNone(bound_module.layers[0].layers[0].scope)
self.assertIsNotNone(bound_module.layers[0].layers[2].scope)
self.assertIsNotNone(bound_module.layers[2].scope)

def test_call_bounded_toplevel_mutable(self):
class Bar(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: 1)

def __call__(self, x):
return x + self.a * self.b

class Foo(nn.Module):
bars: Sequence[Bar]

def __call__(self, x):
for bar in self.bars:
x = bar(x)
return x


module = Foo(bars=[])
module.bars = [Bar(a=1)]

variables = module.init(jax.random.PRNGKey(0), jnp.ones(()))
bound_module = module.bind(variables)

bar1 = bound_module.bars[0]
self.assertIsNotNone(bar1.scope)

def test_nested_init(self):
class Baz(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: jnp.ones(()))

def __call__(self, x):
return x + self.a * self.b

class Bar(nn.Module):
baz: Baz

def setup(self):
a = 1

def __call__(self, x):
return self.baz(x)

class Foo(nn.Module):

def setup(self):
self.bar: Bar = Bar(baz=Baz(a=1))

def __call__(self, x):
# y = self.bar(x)
y, bar_vars = self.bar.init_with_output(jax.random.PRNGKey(0), x)
return y, bar_vars

# create foo
module = Foo()

# run foo
(y, bar_vars), variables = module.init_with_output(
jax.random.PRNGKey(0), jnp.ones(()))

self.assertIn('params', bar_vars)

def test_nested_shared(self):
class Shared(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Dense(1)(x)

class Unshared(nn.Module):
shared: nn.Module
def __call__(self, x):
return self.shared(x)

class Super(nn.Module):
a: nn.Module
b: nn.Module
def run_a(self, x):
return self.a(x)
def run_b(self, x):
return self.b(x)
def __call__(self, x):
return self.a(x) + self.b(x)


sh = Shared()
a = Unshared(shared=sh)
b = Unshared(shared=sh)
module = Super(a=a, b=b)

rng = jax.random.PRNGKey(0)
params = module.init(rng, jnp.ones(1))["params"]

module.apply({"params": params}, jnp.ones(1)) # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_a") # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_b") # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0"

def test_repr(self):

class Base1(nn.Module):
Expand Down

0 comments on commit 0a2ae27

Please sign in to comment.