Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up axis hooks in nnx.Variable #4189

Merged
merged 1 commit into from
Sep 11, 2024
Merged

Conversation

IvyZX
Copy link
Collaborator

@IvyZX IvyZX commented Sep 11, 2024

  • Swapped the order of axis index and name in hooks. This is because in pure JAX, an add/remove axis always needs an index, but perhaps doesn't need a name.

    • This will allow hooks to be triggered on every lifted transform, even if sharding annotations is not part of the model.
  • Simplified the boilderplate code of creating hooks in nnx.Variable.

  • Add a simple test on add/remove axis.

@IvyZX IvyZX requested a review from cgarciae September 11, 2024 00:43
@cgarciae
Copy link
Collaborator

The PR looks good, swapping the arguments makes sense.

We don't have to do this here, but I was wondering if we could conditionally add the hooks if at least one hook is present to avoid adding useless metadata e.g:

if get_value_hooks:
  self.get_value_hooks = get_value_hooks

and check for existence on usage e.g.

if hasattr(self, 'get_value_hooks'):
  for hook in self.get_value_hooks:
    value = hook(self, value)

@IvyZX
Copy link
Collaborator Author

IvyZX commented Sep 11, 2024

The PR looks good, swapping the arguments makes sense.

We don't have to do this here, but I was wondering if we could conditionally add the hooks if at least one hook is present to avoid adding useless metadata e.g:

if get_value_hooks:
  self.get_value_hooks = get_value_hooks

and check for existence on usage e.g.

if hasattr(self, 'get_value_hooks'):
  for hook in self.get_value_hooks:
    value = hook(self, value)

You know what, I tried this, but it ended up with infinite recursion because hasattr(self, 'get_value_hooks') calls __getattr__ which calls hasattr(self, 'get_value_hooks') again... We are Pythonic indeed

We can make get_value_hooks and set_value_hooks always available, but that creates diff between them and other hooks so I'm not so sure

@cgarciae
Copy link
Collaborator

Ah yes I've faced this, we need to we should fix it by raising an AttributeError manually inside __getattr__.

@copybara-service copybara-service bot merged commit 61ea8a6 into google:main Sep 11, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants