Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eagerly bind submodules #3077

Merged
merged 1 commit into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,10 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]:

return method_or_fn

def _map_submodules(fn: Callable[['Module'], Any], tree):
"""Map a function over all submodules in a tree."""
g = lambda _, x: fn(x) if isinstance(x, Module) else x
return _freeze_attr(_map_over_modules_in_tree(g, tree))

class SetupState(enum.IntEnum):
# setup() has not been called.
Expand Down Expand Up @@ -1000,6 +1004,12 @@ def __post_init__(self) -> None:
else:
raise ValueError('parent must be None, Module or Scope')

# eagerly bind submodules if scope is available
if self.scope is not None:
for field in dataclasses.fields(self):
if field.name not in ('parent', 'name') and field.init:
self._register_submodules(field.name, getattr(self, field.name))

self._state.is_initialized = True

def __repr__(self) -> str:
Expand Down Expand Up @@ -1048,8 +1058,8 @@ def _register_submodules(self, name, val):
_caches[root] = cache
queue = []
preserve_adopted_names = config.flax_preserve_adopted_names
if hasattr(self, 'preserve_adopted_names'):
preserve_adopted_names = self.preserve_adopted_names
if hasattr(type(self), 'preserve_adopted_names'):
preserve_adopted_names = type(self).preserve_adopted_names
def adopt_attr_modules(cache, queue, suffix, subvalue):
if isinstance(subvalue, Module):
adopted_name = None
Expand Down Expand Up @@ -1092,7 +1102,7 @@ def _try_setup(self, shallow: bool = False) -> None:
# not call the user's setup. This avoids running before a
# transformation.
for field in dataclasses.fields(self):
if field.name != 'parent' and field.init:
if field.name not in ('parent', 'name') and field.init:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
Expand Down Expand Up @@ -1137,23 +1147,61 @@ def _name_taken(self,

@property
def _initialization_allowed(self):
return self._state.in_setup or self._state.in_compact_method
return (not self._state.is_initialized # allow eager attachment in post-init
or self._state.in_setup
or self._state.in_compact_method)

def clone(self: M, *,
parent: Optional[Union[Scope, 'Module']] = None,
_deep_clone: Union[bool, weakref.WeakValueDictionary] = False,
**updates) -> M:
"""Creates a clone of this Module, with optionally updated arguments.

Args:
parent: The parent of the clone. The clone will have no parent if no
explicit parent is specified.
_deep_clone: A boolean or a weak value dictionary to control deep cloning
of submodules. If True, submodules will be cloned recursively. If a
weak value dictionary is passed, it will be used to cache cloned
submodules. This flag is used by init/apply/bind to avoid scope
leakage.
**updates: Attribute updates.
Returns:
A clone of the this Module with the updated attributes and parent.
"""
attrs = {f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init}

attrs.update(parent=parent, **updates)
return self.__class__(**attrs)

# Here we implement deep cloning of submodules, this is necessary to avoid scope leakage
# from external submodules into init/apply/bind while preserving sharing-by-reference
# relationships between submodules.
if _deep_clone != False:
# We use a weak value dictionary to cache cloned submodules. When a shared
# submodule is cloned, its only cloned once else its fetched from the cache.
cache = weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone
def clone_fn(m: Module) -> Module:
if hasattr(m, '_id'):
key = m._id
if key in cache:
return cache[key]
else:
clone = m.clone(_deep_clone=cache)
cache[key] = clone
return clone
else:
# If the module doesn't have an _id attribute it could be a mock object
# so we return it as is.
return m

# _map_submodules will map over all submodules inside attrs
# value here can be any pytree, non-module values are ignored
for field_name, value in attrs.items():
attrs[field_name] = _map_submodules(clone_fn, value)

module = self.__class__(**attrs)

return module

def variable(self, col: str, name: str,
init_fn: Optional[Callable[..., Any]] = None,
Expand Down Expand Up @@ -1371,7 +1419,7 @@ def __call__(self, x):

del args
scope = core.bind(variables, rngs=rngs, mutable=mutable)
return self.clone(parent=scope)
return self.clone(parent=scope, _deep_clone=True)

def unbind(self: M) -> Tuple[M, VariableDict]:
"""Returns an unbound copy of a Module and its variables.
Expand Down Expand Up @@ -2057,7 +2105,7 @@ def f(foo, x):
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
finally:
_context.capture_stack.pop()

Expand Down Expand Up @@ -2118,7 +2166,7 @@ def f(foo, x):
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
finally:
_context.capture_stack.pop()

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ filterwarnings = [
"ignore:file system plugins are not loaded.*:UserWarning",
"ignore:unable to load libtensorflow_io_plugins.so.*:UserWarning",
# Remove this after next Optax release after 3/27/2023
"ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning"
"ignore:jax.numpy.DeviceArray is deprecated. Use jax.Array.*:DeprecationWarning",
# DeprecationWarning: pkg_resources is deprecated as an API
"ignore:.*pkg_resources is deprecated.*:DeprecationWarning",
# DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
"ignore:.*Deprecated call to `pkg_resources.declare_namespace.*:DeprecationWarning",
]

[tool.coverage.report]
Expand Down
174 changes: 174 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,180 @@ def __call__(self):
'Trying to access a property that'):
foo.apply({})

def test_nested_external_modules(self):
class Baz(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: 2)

def __call__(self, x):
return x + self.a * self.b

class Bar(nn.Module):
baz: Baz

def __call__(self, x):
return self.baz(x)

class Foo(nn.Module):
def setup(self):
self.bar = Bar(baz=Baz(a=1))

def __call__(self, x):
return self.bar.baz(x)

module = Foo()
y, variables = module.init_with_output(jax.random.PRNGKey(0), 1)
self.assertEqual(y, 3)

def test_getattribute_triggers_setup(self):
class B(nn.Module):
def setup(self):
self.p1 = self.param('p1', lambda k: jnp.ones((2,)))
def fn1(self, x):
return self.p1 + x
class A(nn.Module):
b: nn.Module
def __call__(self, x):
return self.b.fn1(x)
a = A(b=B())
k = random.PRNGKey(0)
x = jnp.zeros((2,))
vs = nn.init(lambda a,x: a(x), a)(k, x)
y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x)
np.testing.assert_array_equal(y, jnp.ones((2,)))

def test_nested_sequential_in_call(self):
class Foo(nn.Module):
def setup(self):
self.seq = nn.Sequential([nn.Dense(10) for i in range(10)])

def __call__(self, x):
# try calling only the first layer
return self.seq.layers[0](x)

module = Foo()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))

def test_setup_called_bounded_submodules(self):
module = nn.Sequential([
nn.Sequential([
nn.Dense(2),
nn.relu,
nn.Dense(2),
]),
nn.relu,
nn.Dense(2),
])
x = jnp.ones((1, 3))
variables = module.init(jax.random.PRNGKey(0), x)
bound_module = module.bind(variables)

self.assertIsNotNone(bound_module.layers[0].layers[0].scope)
self.assertIsNotNone(bound_module.layers[0].layers[2].scope)
self.assertIsNotNone(bound_module.layers[2].scope)

def test_call_bounded_toplevel_mutable(self):
class Bar(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: 1)

def __call__(self, x):
return x + self.a * self.b

class Foo(nn.Module):
bars: Sequence[Bar]

def __call__(self, x):
for bar in self.bars:
x = bar(x)
return x


module = Foo(bars=[])
module.bars = [Bar(a=1)]

variables = module.init(jax.random.PRNGKey(0), jnp.ones(()))
bound_module = module.bind(variables)

bar1 = bound_module.bars[0]
self.assertIsNotNone(bar1.scope)

def test_nested_init(self):
class Baz(nn.Module):
a: int

def setup(self):
self.b = self.param('b', lambda k: jnp.ones(()))

def __call__(self, x):
return x + self.a * self.b

class Bar(nn.Module):
baz: Baz

def setup(self):
a = 1

def __call__(self, x):
return self.baz(x)

class Foo(nn.Module):

def setup(self):
self.bar: Bar = Bar(baz=Baz(a=1))

def __call__(self, x):
# y = self.bar(x)
y, bar_vars = self.bar.init_with_output(jax.random.PRNGKey(0), x)
return y, bar_vars

# create foo
module = Foo()

# run foo
(y, bar_vars), variables = module.init_with_output(
jax.random.PRNGKey(0), jnp.ones(()))

self.assertIn('params', bar_vars)

def test_nested_shared(self):
class Shared(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Dense(1)(x)

class Unshared(nn.Module):
shared: nn.Module
def __call__(self, x):
return self.shared(x)

class Super(nn.Module):
a: nn.Module
b: nn.Module
def run_a(self, x):
return self.a(x)
def run_b(self, x):
return self.b(x)
def __call__(self, x):
return self.a(x) + self.b(x)


sh = Shared()
a = Unshared(shared=sh)
b = Unshared(shared=sh)
module = Super(a=a, b=b)

rng = jax.random.PRNGKey(0)
params = module.init(rng, jnp.ones(1))["params"]

module.apply({"params": params}, jnp.ones(1)) # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_a") # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_b") # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0"

def test_repr(self):

class Base1(nn.Module):
Expand Down