Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] Fix State.__sub__ #3704

Merged
merged 1 commit into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions flax/experimental/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from flax import traverse_util
from flax.experimental.nnx.nnx import filterlib, reprlib
from flax.experimental.nnx.nnx.variables import Variable
from flax.typing import Path, Leaf
from flax.typing import Leaf, Path

A = tp.TypeVar('A')

Expand Down Expand Up @@ -180,7 +180,7 @@ def flat_state(self) -> dict[Key, Variable[Leaf]]:
return traverse_util.flatten_dict(self._mapping, sep='/') # type: ignore

@classmethod
def from_flat_path(cls, flat_state: FlatState) -> State:
def from_flat_path(cls, flat_state: FlatState, /) -> State:
nested_state = traverse_util.unflatten_dict(flat_state, sep='/')
return cls(nested_state)

Expand Down Expand Up @@ -274,8 +274,11 @@ def __sub__(self, other: 'State') -> 'State':
if not other:
return self

_mapping = {k: v for k, v in self._mapping.items() if k not in other}
return State(_mapping)
self_flat = self.flat_state()
other_flat = other.flat_state()
diff = {k: v for k, v in self_flat.items() if k not in other_flat}

return State.from_flat_path(diff)


def _state_flatten_with_keys(x: State):
Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,8 @@ def scan_apply(
args,
is_leaf=lambda x: x is None,
)
broadcast_args = jax.tree_util.tree_map(
lambda axis, node: None if axis is not None else node,
broadcast_args = jax.tree_map(
lambda axis, node: node if axis is None else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
Expand Down
Loading