Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align bridge variable tree structures #4194

Merged
merged 1 commit into from
Sep 17, 2024
Merged

Conversation

IvyZX
Copy link
Collaborator

@IvyZX IvyZX commented Sep 14, 2024

Fixing bridge API in a bunch of ways:

  1. ToNNX will convert the whole variables structure to NNX style. If your underlying Linen module has variable foo at collection bar, its ToNNX version will have an attribute foo with type bar, instead of an attribute bar with a dict {'foo': ...}.
    This means you can freely put ToNNX in the top or middle or back of the whole model layer stack, and the weight pytree structure shouldn't change.
    Same goes for ToLinen - if your top-level type is Linen, the whole variable tree shall be Linen-style.

  2. If you have a vanilla nnx.Variable with no sharding metadata, hooks, etc, ToLinen will not convert it into an NNXMeta, but instead just keep the vanilla JAX array inside. This makes it more intuitive and pytree-structure-proof for any Linen users not using partitioning metadata.

  3. nn.get_partition_spec now works on NNXMeta wrappers, and any other wrapper that has get_partition_spec method.

  4. Updated the nnx.bridge guide accordingly.

@IvyZX IvyZX requested review from levskaya and cgarciae September 14, 2024 00:33
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@IvyZX IvyZX force-pushed the bdg-tree branch 3 times, most recently from 8eceb3c to 172e832 Compare September 14, 2024 00:55
import dataclasses
import typing as tp
from typing import Any

from flax import nnx
from flax import linen
from flax import traverse_util
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: flax.nnx.traversals has similar same APIs but with more accurate typing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Adopted.

@@ -85,6 +87,50 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
return fn


def _recursive_merge(dict1, dict2):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be expressed using traversals / traverse_util as:

flat_map = traversal.flatten_mapping(dict1)
flat_map |= traversal.flatten_mapping(dict2)
return traversal.unflatten_mapping(flat_map)

Not sure if there are edge cases though but it seems easy.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Adopted.

dict1[key] = _recursive_merge(dict1[key], value)
else:
# Merge non-dictionary values
dict1[key] = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is mutating dict1 always safe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using nnx.traversals is indeed easier!

def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
linen_structured = {}
for kp, v in traverse_util.flatten_dict(
nnx_attrs, is_leaf=lambda _, x: isinstance(x, nnx.Variable | nnx.GraphDef)).items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know Unions were now valid in isinstance, nice!

@IvyZX IvyZX force-pushed the bdg-tree branch 2 times, most recently from 55cbdef to cac253e Compare September 16, 2024 23:50
@copybara-service copybara-service bot merged commit ddaef57 into google:main Sep 17, 2024
17 checks passed
@IvyZX IvyZX deleted the bdg-tree branch September 17, 2024 17:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants