
Add shape-based lazy init to LinenToNNX (prev LinenWrapper) #4081

Merged: 1 commit into google:main on Jul 22, 2024

Conversation

@IvyZX (Collaborator) commented on Jul 13, 2024

  • Renamed the wrappers to LinenToNNX and NNXToLinen to minimize confusion.
  • Moved state initialization of LinenToNNX into __call__ to realize lazy init. This allows it to be a submodule of an NNX module, which has no input args available during initialization.
    • Users can call nnx.shaped_init to do a dry run of __call__ and initialize the whole state and full graphdef (see the usage sketch after this list).
  • Made state initialization of LinenToNNX nested and closer to NNX, i.e. a VariableState is created for every JAX array, not for every collection.
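A minimal usage sketch of that workflow, assuming a constructor of the form LinenToNNX(module, rngs=...) and the dry-run helper described above (called nnx.shaped_init in this description and renamed lazy_init later in the thread); the import paths and exact signatures are assumptions, not the merged API:

import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

# Hedged sketch: constructor and helper signatures are assumptions based on
# the PR description above.
class Outer(nnx.Module):
  def __init__(self, dout: int, *, rngs: nnx.Rngs):
    # The Linen submodule is wrapped without knowing its input shape;
    # its variables are only created on the first __call__.
    self.linear = bridge.LinenToNNX(nn.Dense(dout), rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

model = Outer(4, rngs=nnx.Rngs(0))
# Dry run with a sample input to materialize every variable and the full graphdef.
graphdef, state = nnx.shaped_init(model, jnp.ones((1, 3)))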

@IvyZX IvyZX requested a review from cgarciae July 13, 2024 00:27
@IvyZX IvyZX force-pushed the bridge branch 3 times, most recently from dd5936c to 2c8963d Compare July 15, 2024 23:32
@codecov-commenter commented on Jul 15, 2024

Codecov Report

Attention: Patch coverage is 0% with 51 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (31adb00) to head (0139c90).
Report is 140 commits behind head on main.

Files                             Patch %   Lines
flax/nnx/nnx/bridge/wrappers.py   0.00%     48 Missing ⚠️
flax/nnx/nnx/bridge/__init__.py   0.00%     3 Missing ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##            main   #4081    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files        106     108     +2     
  Lines      13582   14045   +463     
======================================
- Misses     13582   14045   +463     


Comment on lines 90 to 91
_rngs['params'] = _rngs['default']
del _rngs['default']
Collaborator:

Suggested change
_rngs['params'] = _rngs['default']
del _rngs['default']
_rngs['params'] = _rngs.pop('default')

Collaborator:

Now that I think about it, we have to make Rngs implement MutableMapping for either of these to work.

Collaborator (Author):

The _rngs here is a plain dict, not an Rngs instance, so this already works.

"""To trigger init of all `LinenToNNX` module variables and return a wholesome state."""
assert callable(module)
_ = module(*args, **kwargs)
return nnx.split(module)
Collaborator:

In this function we should leverage the fact that Object._object__state._initializing still exists and set it via something like

def _set_initializing(initializing: bool):
  for _, value in graph.iter_graph(module):
    if isinstance(value, Object):
      value._object__state._initializing = initializing

and use the value of _initializing to choose between init and apply when calling the Linen Modules.
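
A standalone sketch of that init-vs-apply branching; LazyLinen and its attributes below are illustrative stand-ins (with a plain boolean playing the role of _object__state._initializing), not the flax wrapper itself:

import jax
import jax.numpy as jnp
import flax.linen as nn

class LazyLinen:
  # Illustrative stand-in: `_initializing` mimics Object._object__state._initializing,
  # and `variables` mimics the state the real wrapper would store as nnx variables.
  def __init__(self, module: nn.Module, seed: int = 0):
    self.module = module
    self.key = jax.random.key(seed)
    self.variables = None
    self._initializing = True

  def __call__(self, *args, **kwargs):
    if self._initializing:
      # First (dry-run) call: let Linen create its variables from the real inputs.
      self.variables = self.module.init({'params': self.key}, *args, **kwargs)
      self._initializing = False
    # Run the module with the stored variables (on the first and all later calls).
    return self.module.apply(self.variables, *args, **kwargs)

layer = LazyLinen(nn.Dense(4))
y = layer(jnp.ones((1, 3)))  # init happens here
y = layer(jnp.ones((1, 3)))  # plain apply from now on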

Collaborator (Author):

Added _set_initializing to the LinenToNNX wrapper.
Note that we can't just check the top-level module's ._object__state._initializing, because the top-level module might be a pure NNX module whose ._object__state._initializing is always False.

Collaborator:

I'd do something like this:

def _set_initializing(module, initializing: bool): 
  for _, value in graph.iter_graph(module): 
    if isinstance(value, Object): 
      value._object__state._initializing = initializing
      
def shaped_init(module: Module, *args, **kwargs):
  """To trigger init of all `LinenToNNX` module variables and return a wholesome state."""
  module = graph.clone(module) # create a copy
  _set_initializing(module, True)
  assert callable(module)
  try:
    _ = module(*args, **kwargs)
  finally:
    _set_initializing(module, False)
  return nnx.split(module)

Collaborator (Author):

Done, also renamed to lazy_init as discussed offline.

# Shape-based lazy init of the flax variables
if not rngs:
  rngs = self.rngs
if not hasattr(self, 'states'):
Collaborator:

Use self._object__state.initializing instead, see above.

Suggested change
if not hasattr(self, 'states'):
if self._object__state.initializing:

Collaborator (Author):

Done.

  rngs = self.rngs
if self._object__state.initializing:
  _rngs = (
    {name: stream.key.raw_value for name, stream in rngs.items()}
Collaborator:

We need to generate new keys so Linen Modules get new RNG state every time.

Suggested change
{name: stream.key.raw_value for name, stream in rngs.items()}
{name: stream() for name, stream in rngs.items()}
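
A small self-contained illustration of the difference; the stream name is arbitrary and this only assumes the standard nnx.Rngs behavior of folding an incrementing count into the base key on each stream call:

import jax
from flax import nnx

rngs = nnx.Rngs(dropout=0)  # 'dropout' is an arbitrary stream name for illustration
# Reading the raw key returns the same value every time, so the wrapped Linen
# module would see identical randomness on every call.
raw1 = rngs.dropout.key.raw_value
raw2 = rngs.dropout.key.raw_value
assert (jax.random.key_data(raw1) == jax.random.key_data(raw2)).all()
# Calling the stream folds in a fresh count, so each call yields a new key.
k1 = rngs.dropout()
k2 = rngs.dropout()
assert not (jax.random.key_data(k1) == jax.random.key_data(k2)).all()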

Collaborator (Author):

Done.

@cgarciae cgarciae mentioned this pull request Jul 18, 2024
if 'params' not in _rngs and 'default' in _rngs:
  _rngs['params'] = _rngs.pop('default')

variables = self.module.init(_rngs, *args, **kwargs)
Collaborator:

We could use init_with_output to avoid calling the forward pass twice.
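
For reference, a hedged stand-alone sketch of the two options; nn.Dense and the local names below are stand-ins for self.module and the rngs dict in the excerpt above:

import jax
import jax.numpy as jnp
import flax.linen as nn

module = nn.Dense(4)                   # stand-in for self.module
_rngs = {'params': jax.random.key(0)}  # stand-in for the rngs dict built above
x = jnp.ones((1, 3))

# Two forward passes: one inside init, a second one in apply.
variables = module.init(_rngs, x)
out = module.apply(variables, x)

# One forward pass: init_with_output returns the output and the new variables together.
out, variables = module.init_with_output(_rngs, x)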

Collaborator (Author):

Done.

@IvyZX IvyZX force-pushed the bridge branch 3 times, most recently from e9fa1e0 to 3f33ad8 Compare July 18, 2024 22:40
@copybara-service copybara-service bot merged commit d8bc194 into google:main Jul 22, 2024
15 of 16 checks passed