-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] Arrays are state #3791
[nnx] Arrays are state #3791
Conversation
f6a1846
to
f2e7de7
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you motivate this PR a bit more in the description? Why not require users to wrap their arrays into variables?
str_path = '/'.join((*path, key)) | ||
if value in ref_to_index: | ||
variables.append((key, ref_to_index[value])) | ||
if isinstance(value, Variable): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: maybe move this into a separate branch before is_state_leaf
to reduce the nesting here?
elif isinstance(value, Variable):
...
elif is_state_leaf(value):
...
@@ -746,23 +756,32 @@ def _graph_update_dynamic( | |||
|
|||
# case 2: subgraph is being updated | |||
if is_node(current_value): | |||
if isinstance(value, Variable): | |||
if is_state_leaf(value): | |||
raise ValueError( | |||
f'Expected a subgraph for {key!r}, but got a Variable: {value!r}' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error message needs updating.
raise ValueError( | ||
f'Expected a subgraph for {key!r}, but got a Variable: {value!r}' | ||
) | ||
assert not isinstance(value, Variable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this assert
only needed to please the type checker?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed
flax/experimental/nnx/nnx/state.py
Outdated
|
||
def __getattr__(self, key: Key) -> Variable | State: | ||
def __getattr__(self, key: Key) -> tp.Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the return type genuinely Any
here? It looks it could be StateLeaf | State
where
StateLeaf = Union[Variable, jax.Array, np.ndarray]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wow, coded this exact solution after implementing the TypeGuard :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
...but hadn't seen you suggested exactly this
@@ -152,6 +153,9 @@ def is_node(x: tp.Any) -> bool: | |||
def is_node_type(x: type[tp.Any]) -> bool: | |||
return x in NODE_TYPES or x is PytreeType | |||
|
|||
def is_state_leaf(x: tp.Any) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be a typing.TypeGuard
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great idea! I should start using TypeGuard
more
@@ -152,6 +153,9 @@ def is_node(x: tp.Any) -> bool: | |||
def is_node_type(x: type[tp.Any]) -> bool: | |||
return x in NODE_TYPES or x is PytreeType | |||
|
|||
def is_state_leaf(x: tp.Any) -> bool: | |||
return isinstance(x, (Variable, np.ndarray, jax.Array)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I right that you're intentionally disallowing Python scalars here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main problem here is that Variable
is registered as a pytree so it would get accepted but we explicitly want to forbid it.
flax/experimental/nnx/nnx/state.py
Outdated
|
||
|
||
def _state_unflatten( | ||
static: tp.Tuple[Path, ...] | None, | ||
static: tp.Tuple[Path, ...], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the built in tuple
here and elsewhere.
4ef0143
to
13c0c3a
Compare
13c0c3a
to
8718110
Compare
What does this PR do?
jax.Array
andnp.ndarray
are now treated as state same asVariables
.Discussion
Previously we just treated all non-Variable and non-node as static, however, since we support pytrees and certain pytrees such as
opt_state
in theOptimizer
class from #3782 sometimes include arrays such as the trainingstep
array, then we should also support Array. This sadly breaks from the simplicity ofState
being just a nested mapping withVariables
as leaves, but on the other hand Array is a very natural state type to try to support and in fact the NNX already supported it in the past.