Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] Arrays are state #3791

Merged
merged 1 commit into from
Mar 27, 2024
Merged

[nnx] Arrays are state #3791

merged 1 commit into from
Mar 27, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Mar 27, 2024

What does this PR do?

jax.Array and np.ndarray are now treated as state same as Variables.

Discussion

Previously we just treated all non-Variable and non-node as static, however, since we support pytrees and certain pytrees such as opt_state in the Optimizer class from #3782 sometimes include arrays such as the training step array, then we should also support Array. This sadly breaks from the simplicity of State being just a nested mapping with Variables as leaves, but on the other hand Array is a very natural state type to try to support and in fact the NNX already supported it in the past.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Member

@superbobry superbobry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you motivate this PR a bit more in the description? Why not require users to wrap their arrays into variables?

str_path = '/'.join((*path, key))
if value in ref_to_index:
variables.append((key, ref_to_index[value]))
if isinstance(value, Variable):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: maybe move this into a separate branch before is_state_leaf to reduce the nesting here?

elif isinstance(value, Variable):
  ...
elif is_state_leaf(value):
  ...

@@ -746,23 +756,32 @@ def _graph_update_dynamic(

# case 2: subgraph is being updated
if is_node(current_value):
if isinstance(value, Variable):
if is_state_leaf(value):
raise ValueError(
f'Expected a subgraph for {key!r}, but got a Variable: {value!r}'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message needs updating.

raise ValueError(
f'Expected a subgraph for {key!r}, but got a Variable: {value!r}'
)
assert not isinstance(value, Variable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this assert only needed to please the type checker?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed


def __getattr__(self, key: Key) -> Variable | State:
def __getattr__(self, key: Key) -> tp.Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the return type genuinely Any here? It looks it could be StateLeaf | State where

StateLeaf = Union[Variable, jax.Array, np.ndarray]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow, coded this exact solution after implementing the TypeGuard :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...but hadn't seen you suggested exactly this

@@ -152,6 +153,9 @@ def is_node(x: tp.Any) -> bool:
def is_node_type(x: type[tp.Any]) -> bool:
return x in NODE_TYPES or x is PytreeType

def is_state_leaf(x: tp.Any) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a typing.TypeGuard?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea! I should start using TypeGuard more

@@ -152,6 +153,9 @@ def is_node(x: tp.Any) -> bool:
def is_node_type(x: type[tp.Any]) -> bool:
return x in NODE_TYPES or x is PytreeType

def is_state_leaf(x: tp.Any) -> bool:
return isinstance(x, (Variable, np.ndarray, jax.Array))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right that you're intentionally disallowing Python scalars here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Main problem here is that Variable is registered as a pytree so it would get accepted but we explicitly want to forbid it.



def _state_unflatten(
static: tp.Tuple[Path, ...] | None,
static: tp.Tuple[Path, ...],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the built in tuple here and elsewhere.

@cgarciae cgarciae force-pushed the nnx-arrays-are-state branch 3 times, most recently from 4ef0143 to 13c0c3a Compare March 27, 2024 15:15
@copybara-service copybara-service bot merged commit 0ab0365 into main Mar 27, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-arrays-are-state branch March 27, 2024 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants