Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Solves a couple of issues regarding the process of binding submodules from dataclass fields.
Current situation: when does X happen?
init
/bind
/apply
on a Module.clone(parent=scope)
is ran, during__post_init__
(which happens at the end ofclone
) there is a simple line that directly sets__setattr__
) of a bounded module, this mostly happens duringsetup
:_try_setup
is called, this happens either before a method is called (_call_wrapped_method
) or when a non-dataclass field is accessed (__getattr__
)._register_submodules
on one or more submodules,__setattr__
calls it on the Module being set while_try_setup
looks through all dataclass fields in addition to callingsetup
(which possibly triggers the previous case)._register_submodules
clones the Modules being registered (while preserving sharing), then sets the registering Module as theparent
of the Module being registered, and finally calls__post_init__
on the registered submodule which internally runsscope.push
to create a scope for the registered child:Case 1: Situation with externally defined submodules
The situation above relies a lot in lazy behavior i.e.
_try_setup
is only called when a method is called or someone tries a yet undefined field that is lazily defined insidesetup
, because its initially undefined__getattr__
is triggered which calls_try_setup
. There is a hole with this setup mechanism: submodules passed from the outside via dataclass fields do exist at all time. This means__getatrr__
is never triggered and nested submodules are possibly unbounded, here is a edge case that one would expect to work but doesn't in the current code base:An identical situation would happen if you where to try to use
bind
instead:Solution
The main problem with the scenarios described above is that
__getattr__
only gets triggered if the requested attribute is not currently found on the object, therefore, when you access.layers
in the examples above__getattr__
is not triggered becauselayers
is present at construction time. So solve this, this PR proposes the implementation of the__getattribute__
method.__getattribute__
is triggered when you access any attribute from the object, so we can use it to detect when dataclass attributes are being requested to solve the above problem. However, we have to be careful because__getattribute__
actually gets triggered A LOT (e.g. method, fields, class properties, etc), to minimize python overhead we create a_submodule_dataclass_fields
during__post_init__
that contains the names of the fields that contain submodules within them and triggered_try_setup
if only these are are being accessed.Additional Fix: Deep cloning
When testing this change it was revealed that some internal code was now breaking because of this change, possibly because we where now binding more submodules that before and the existing code relied heavily on lazy binding. The culprit ended up being that we getting some scopes leaked into
init
/apply
. The solution was to add a private_deep_clone: bool = False
flag to theclone
method that would clone all child submodules as well to get rid of any external scopes, this flag is only activated byinit
,apply
, andbind
.Case 2: Naming race conditions with externally defined submodules
Laziness makes is very difficult to reason about certain situations and interestingly can trigger naming race conditions on some cases, here is an edge case where a shared Module can get assigned under different paths depending on which branch runs first:
So the problem is that even though all of
Super
's submodules are defined on construction, when aSuper
instance gets a scope duringinit
/apply
its submodules don't immediately get one, its up to runtime behavior to see which submodules get a scope and when.Solution
The main idea to solve this is to recursively call
_register_submodules
as soon on all child Modules from dataclass fields as soon a parent Module gets a scope, this avoids any race condition since binding would always occur in the same order. Currently this is only being done inside at the end ofclone
when_deep_clone is True
, this works for the reported case, but one might wonder this could be called more often e.g. at the end of__post_init__
ifself.scope is not None
.On detail during the implementation is that to call
_register_submodule
without it complaining about being called outside ofsetup
we had to trick it by temporarily setting_state.in_setup = True
:We can consider adding a new state that
_register_submodules
recognizes to avoid abusing the mechanism.