Properly setup input submodules #2884

Closed
wants to merge 1 commit into from

cgarciae
Collaborator

@cgarciae cgarciae commented Feb 17, 2023

What does this PR do?

Solves a couple of issues regarding the process of binding submodules from dataclass fields.

Current situation: when does X happen?

  • module gets a scope: when you use init/bind/apply on a Module, .clone(parent=scope) is run. During __post_init__ (which happens at the end of clone) a single line directly sets
self.scope = self.parent
  • submodules are bound: this is a two-step process:
    1. Something must trigger the binding of a submodule of a bound module. There are several such triggers:
      • A submodule is set to a non-dataclass field (__setattr__) of a bound module. This mostly happens during setup:
        def setup(self): # self is bound
          self.dense = nn.Dense(10) # triggers __setattr__
      • _try_setup is called. This happens either before a method is called (_call_wrapped_method) or when a not-yet-defined non-dataclass field is accessed (__getattr__).
    2. These triggers internally call _register_submodules on one or more submodules: __setattr__ calls it on the Module being set, while _try_setup goes through all dataclass fields in addition to calling setup (which may in turn trigger the previous case). _register_submodules clones the Modules being registered (while preserving sharing), sets the registering Module as the parent of each registered Module, and finally calls __post_init__ on the registered submodule, which internally runs scope.push to create a scope for the registered child:
    self.scope = self.parent.scope.push(self.name, reuse=reuse_scopes)

Case 1: Situation with externally defined submodules

The mechanism above relies heavily on lazy behavior: _try_setup is only called when a method is called or when someone accesses a field that is lazily defined inside setup. Because such a field is initially undefined, the access triggers __getattr__, which calls _try_setup. There is a hole in this mechanism: submodules passed from the outside via dataclass fields exist at all times. This means __getattr__ is never triggered for them, and nested submodules may remain unbound. Here is an edge case that one would expect to work but doesn't in the current code base:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  def setup(self):
    self.seq = nn.Sequential([nn.Dense(10) for i in range(10)])

  def __call__(self, x):
    # try calling only the first layer
    return self.seq.layers[0](x) # Error: Can't call compact methods on unbound modules
  

module = Foo()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10)))

An identical situation occurs if you try to use bind instead:

module = nn.Sequential([nn.Dense(10) for i in range(10)])
x = jnp.ones((1, 10))
variables = module.init(jax.random.PRNGKey(0), x)
bound_module = module.bind(variables)

# try calling only the first layer of the bound module
bound_module.layers[0](x) # Error: Can't call compact methods on unbound modules

Solution

The main problem with the scenarios described above is that __getattr__ only gets triggered if the requested attribute is not found on the object. When you access .layers in the examples above, __getattr__ is not triggered because layers is present from construction time. To solve this, this PR proposes implementing the __getattribute__ method.

__getattribute__ is triggered when you access any attribute of the object, so we can use it to detect when dataclass attributes are being requested. However, we have to be careful because __getattribute__ gets triggered very often (e.g. for methods, fields, class properties, etc.). To minimize Python overhead, we create a _submodule_dataclass_fields attribute during __post_init__ that contains the names of the fields that contain submodules, and we trigger _try_setup only when one of these fields is accessed.
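The filtering idea can be sketched in plain Python as follows. This is an illustrative toy under stated assumptions, not the PR's implementation: `EagerFields`, `_submodule_fields`, and the counting `_try_setup` are hypothetical stand-ins, and the real hook would bind submodules rather than increment a counter.

```python
class EagerFields:
    # Hypothetical stand-in for the names collected during __post_init__.
    _submodule_fields = frozenset({"layers"})

    def __init__(self):
        self.layers = ["layer_0"]  # field that holds submodules
        self.name = "toy"          # ordinary field
        self.setup_calls = 0       # counts how often the hook fired

    def _try_setup(self):
        self.setup_calls += 1

    def __getattribute__(self, name):
        # Runs on every attribute access, so the check is a cheap set lookup.
        if name in EagerFields._submodule_fields:
            # Use object.__getattribute__ to avoid recursing into this method.
            object.__getattribute__(self, "_try_setup")()
        return object.__getattribute__(self, name)


obj = EagerFields()
obj.name    # not a submodule field, the hook is skipped
obj.layers  # submodule field, the hook fires
print(obj.setup_calls)  # 1
```

Looking the set up on the class and delegating to `object.__getattribute__` keeps the override from calling itself, which is the usual pitfall when overriding this method.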

Additional Fix: Deep cloning

Testing this change revealed that some internal code was now breaking, possibly because we were now binding more submodules than before and the existing code relied heavily on lazy binding. The culprit turned out to be some scopes leaking into init/apply. The solution was to add a private _deep_clone: bool = False flag to the clone method that clones all child submodules as well, to get rid of any external scopes. This flag is only activated by init, apply, and bind.
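The effect of such a flag can be illustrated with a small dataclass sketch. Everything here is an assumption made for illustration: `ToyModule` is not a Flax class, the `scope` field is a plain placeholder, and this toy ignores the sharing-preservation that the real registration code performs.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class ToyModule:
    name: str
    child: Optional["ToyModule"] = None
    scope: Any = None  # placeholder for a scope object

    def clone(self, _deep_clone: bool = False) -> "ToyModule":
        child = self.child
        if _deep_clone and child is not None:
            # Recursively clone children so they carry no external scope.
            child = child.clone(_deep_clone=True)
        # In this toy, the clone itself always starts without a scope.
        return dataclasses.replace(self, child=child, scope=None)


inner = ToyModule("inner", scope="external-scope")
outer = ToyModule("outer", child=inner)

shallow = outer.clone()
print(shallow.child.scope)  # external-scope (the child's scope leaked)

deep = outer.clone(_deep_clone=True)
print(deep.child.scope)  # None
```

The shallow clone reuses the same child object, so whatever scope that child already holds travels with it; the deep clone replaces the child with a fresh copy and leaves the original untouched.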

Case 2: Naming race conditions with externally defined submodules

Laziness makes it very difficult to reason about certain situations and, interestingly, can trigger naming race conditions in some cases. Here is an edge case where a shared Module can get assigned under different paths depending on which branch runs first:

class Shared(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(x)

class Unshared(nn.Module):
  shared: nn.Module
  def __call__(self, x):
    return self.shared(x)

class Super(nn.Module):
  a: nn.Module
  b: nn.Module

  def run_a(self, x):
    return self.a(x)

  def run_b(self, x):
    return self.b(x)

  def __call__(self, x):
    return self.a(x) + self.b(x) # 'a' runs first, variables are initialized with an 'a' path

sh = Shared()
module = Super(a=Unshared(shared=sh), b=Unshared(shared=sh))

rng = jax.random.PRNGKey(0)
params = module.init(rng, jnp.ones(1))["params"]

module.apply({"params": params}, jnp.ones(1))  # works as expected
module.apply({"params": params}, jnp.ones(1), method="run_a")  # works because 'a' is present
# this case fails because self.b runs first but variables only contain 'a'
module.apply({"params": params}, jnp.ones(1), method="run_b")  # ScopeParamNotFoundError: Could not find 

So the problem is that even though all of Super's submodules are defined on construction, when a Super instance gets a scope during init/apply its submodules don't immediately get one. It is up to runtime behavior which submodules get a scope and when.

Solution

The main idea is to recursively call _register_submodules on all child Modules from dataclass fields as soon as a parent Module gets a scope. This avoids any race condition, since binding always occurs in the same order. Currently this is only done at the end of clone when _deep_clone is True. This works for the reported case, but one might wonder whether it should be called more often, e.g. at the end of __post_init__ if self.scope is not None.
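A toy model of eager, ordered binding may help show why the order becomes deterministic. The `Node` class and `bind_eagerly` function below are hypothetical, and the "first path wins" rule for shared nodes is an assumption chosen to mirror the 'a' path behavior in the example above, not a statement about Flax internals.

```python
class Node:
    """Toy module: keyword arguments play the role of dataclass fields."""

    def __init__(self, **children):
        self.children = children  # dicts preserve insertion (field) order
        self.path = None          # assigned once the node is bound


def bind_eagerly(node, path=()):
    # Bind depth-first in field order; a shared node keeps its first path.
    if node.path is not None:
        return
    node.path = path
    for field_name, child in node.children.items():
        bind_eagerly(child, path + (field_name,))


shared = Node()
root = Node(a=Node(shared=shared), b=Node(shared=shared))
bind_eagerly(root)  # happens once, before any method runs
print(shared.path)  # ('a', 'shared')
```

Because the traversal runs as soon as the root is bound, the shared node's path no longer depends on whether a method touching `a` or one touching `b` happens to run first.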

One detail of the implementation: to call _register_submodules without it complaining about being called outside of setup, we had to trick it by temporarily setting _state.in_setup = True:

current_in_setup = self._state.in_setup  # remember the flag so it can be restored
try:
  self._state.in_setup = True
  self._register_submodules(field_name, value)
finally:
  self._state.in_setup = current_in_setup

We can consider adding a new state that _register_submodules recognizes to avoid abusing the mechanism.
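Independently of whether a new state is added, the save-and-restore pattern around the flag could be wrapped in a small context manager. The sketch below is only a suggestion: `temporarily` is a hypothetical helper, and `SimpleNamespace` stands in for the module's internal state object.

```python
import contextlib
from types import SimpleNamespace


@contextlib.contextmanager
def temporarily(state, attr, value):
    """Set state.<attr> to value inside the block, then restore the old value."""
    previous = getattr(state, attr)
    setattr(state, attr, value)
    try:
        yield
    finally:
        # Restore even if the body raises.
        setattr(state, attr, previous)


state = SimpleNamespace(in_setup=False)  # stand-in for a module's state
with temporarily(state, "in_setup", True):
    inside = state.in_setup  # True while the block runs
print(inside, state.in_setup)  # True False
```

This keeps the override visibly scoped to the registration call and guarantees restoration on errors, though a dedicated state value would still express the intent more directly.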

@cgarciae cgarciae changed the title from "fix mypy issue" to "Properly setup input submodules" Feb 17, 2023
@codecov-commenter

codecov-commenter commented Feb 17, 2023

Codecov Report

Merging #2884 (bfb358b) into main (27c37ed) will increase coverage by 0.23%.
The diff coverage is 96.15%.

@@            Coverage Diff             @@
##             main    #2884      +/-   ##
==========================================
+ Coverage   81.45%   81.69%   +0.23%     
==========================================
  Files          55       55              
  Lines        5803     5855      +52     
==========================================
+ Hits         4727     4783      +56     
+ Misses       1076     1072       -4     
Impacted Files              Coverage Δ
flax/linen/module.py        93.24% <96.15%> (+0.86%) ⬆️
flax/core/__init__.py       100.00% <0.00%> (ø)
flax/core/scope.py          89.80% <0.00%> (+0.23%) ⬆️
flax/core/frozen_dict.py    96.58% <0.00%> (+0.25%) ⬆️


@cgarciae cgarciae force-pushed the getattribute-v5 branch 2 times, most recently from 1d67d4e to eb4e03b Compare February 22, 2023 23:59
@cgarciae
Collaborator Author

Closed in favor of #3077.

@cgarciae cgarciae closed this Jun 20, 2023