Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FlatState a Mapping instead of a dict #3880

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions flax/experimental/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# limitations under the License.
from __future__ import annotations

from collections.abc import Mapping
import typing as tp
import typing_extensions as tpe

Expand All @@ -42,7 +43,7 @@
A = tp.TypeVar('A')

StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
FlatState = dict[PathParts, StateLeaf]
FlatState = Mapping[PathParts, StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
Expand All @@ -66,8 +67,8 @@ class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable):
def __init__(
self,
mapping: tp.Union[
tp.Mapping[Key, tp.Mapping | StateLeaf],
tp.Iterator[tuple[Key, tp.Mapping | StateLeaf]],
Mapping[Key, Mapping | StateLeaf],
tp.Iterator[tuple[Key, Mapping | StateLeaf]],
],
/,
):
Expand Down
49 changes: 40 additions & 9 deletions flax/traverse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,19 @@
Traversals never mutate the original data. Therefore, an update essentially
returns a copy of the data including the provided updates.
"""
from __future__ import annotations

import abc
from collections.abc import Callable, Mapping
import copy
import dataclasses
import warnings
from typing import Any, Callable
from typing import Any, Union, overload

import jax

import flax
from flax.core.scope import VariableDict
from flax.typing import PathParts
from flax.typing import PathParts, VariableDict

from . import struct

Expand All @@ -77,7 +78,37 @@ class _EmptyNode:
empty_node = _EmptyNode()


def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None):
# TODO: In Python 3.10, use TypeAlias.
IsLeafCallable = Callable[[tuple[Any, ...], Mapping[Any, Any]], bool]


@overload
def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: Union[None, IsLeafCallable] = None,
sep: None = None
) -> dict[tuple[Any, ...], Any]:
...

@overload
def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: Union[None, IsLeafCallable] = None,
sep: str,
) -> dict[str, Any]:
...

def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: Union[None, IsLeafCallable] = None,
sep: Union[None, str] = None
) -> dict[Any, Any]:
"""Flatten a nested dictionary.

The nested keys are flattened to a tuple.
Expand Down Expand Up @@ -111,16 +142,16 @@ def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None):
The flattened dictionary.
"""
assert isinstance(
xs, (flax.core.FrozenDict, dict)
), f'expected (frozen)dict; got {type(xs)}'
xs, Mapping
), f'expected Mapping; got {type(xs).__qualname__}'

def _key(path):
def _key(path: tuple[Any, ...]) -> Union[tuple[Any, ...], str]:
if sep is None:
return path
return sep.join(path)

def _flatten(xs, prefix):
if not isinstance(xs, (flax.core.FrozenDict, dict)) or (
def _flatten(xs: Any, prefix: tuple[Any, ...]) -> dict[Any, Any]:
if not isinstance(xs, Mapping) or (
is_leaf and is_leaf(prefix, xs)
):
return {_key(prefix): xs}
Expand Down
Loading