[nnx] Add LinearGeneral and MultiHeadAttention #3487

Merged · 2 commits · Nov 29, 2023
6 changes: 3 additions & 3 deletions examples/lm1b/utils.py
@@ -86,7 +86,7 @@ def fill_unspecified_mesh_axes(
       " parallelism axis. At most one axis can be unspecified."
   )
 
-  determined_val = target_product / np.product(parallelism_vals) * -1
+  determined_val = target_product / np.prod(parallelism_vals) * -1
 
   assert determined_val >= 1 and determined_val.is_integer, (
       "Unspecified value unable to be determined with the given "
@@ -97,9 +97,9 @@ def fill_unspecified_mesh_axes(
 
   target_type = "slices" if parallelism_type == "DCN" else "devices per slice"
 
-  assert np.product(parallelism_vals) == target_product, (
+  assert np.prod(parallelism_vals) == target_product, (
       f"Number of {target_type} {target_product} does not match the product"
-      f" of the {parallelism_type} parallelism {np.product(parallelism_vals)}"
+      f" of the {parallelism_type} parallelism {np.prod(parallelism_vals)}"
   )
 
   return parallelism_vals
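Background for this file's change: np.product is a deprecated alias of np.prod (deprecation enforced in NumPy 1.25, removal in 2.0), so the example switches to the canonical name. A minimal sketch of the same mesh-axis computation, with illustrative values rather than a real config:

import numpy as np

# Illustrative values: 8 devices total, one parallelism axis left as -1.
target_product = 8
parallelism_vals = [2, -1, 2]

# np.prod([2, -1, 2]) == -4, so dividing and flipping the sign
# recovers the unspecified axis: 8 / -4 * -1 == 2.
determined_val = target_product / np.prod(parallelism_vals) * -1

# Note the call parentheses on is_integer(); the bare bound method in the
# surrounding diff context is always truthy, so that assert never fires.
assert determined_val >= 1 and determined_val.is_integer()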
2 changes: 2 additions & 0 deletions flax/experimental/nnx/__init__.py
@@ -62,9 +62,11 @@
 from .nnx.nn.activations import standardize as standardize
 from .nnx.nn.activations import swish as swish
 from .nnx.nn.activations import tanh as tanh
+from .nnx.nn.attention import MultiHeadAttention as MultiHeadAttention
 from .nnx.nn.linear import Conv as Conv
 from .nnx.nn.linear import Embed as Embed
 from .nnx.nn.linear import Linear as Linear
+from .nnx.nn.linear import LinearGeneral as LinearGeneral
 from .nnx.nn.normalization import BatchNorm as BatchNorm
 from .nnx.nn.normalization import LayerNorm as LayerNorm
 from .nnx.nn.stochastic import Dropout as Dropout
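The two new exports mirror Linen's DenseGeneral and MultiHeadDotProductAttention. A hypothetical usage sketch follows; the keyword names (in_features, out_features, num_heads) and the rngs=nnx.Rngs(0) convention are assumptions about the experimental API at this commit and may differ slightly:

import jax.numpy as jnp
from flax.experimental import nnx

# Assumed RNG container for the experimental API at this commit.
rngs = nnx.Rngs(0)

# LinearGeneral generalizes Linear to multiple output (and input) axes.
layer = nnx.LinearGeneral(in_features=8, out_features=(4, 2), rngs=rngs)
y = layer(jnp.ones((3, 8)))    # expected shape: (3, 4, 2)

# MultiHeadAttention performs self-attention when given a single input.
attn = nnx.MultiHeadAttention(num_heads=2, in_features=8, rngs=rngs)
z = attn(jnp.ones((3, 5, 8)))  # expected shape: (3, 5, 8)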
12 changes: 5 additions & 7 deletions flax/experimental/nnx/nnx/flaglib.py
@@ -21,18 +21,18 @@
 
 @dataclasses.dataclass
 class FlagsContext(threading.local):
-  flags_stack: tp.List[MappingProxyType[str, tp.Hashable]] = dataclasses.field(
+  flags_stack: tp.List[MappingProxyType[str, tp.Any]] = dataclasses.field(
     default_factory=lambda: [MappingProxyType({})]
   )
 
 
 FLAGS_CONTEXT = FlagsContext()
 
 
-class Flags(tp.Mapping[str, tp.Hashable]):
+class Flags(tp.Mapping[str, tp.Any]):
   __slots__ = ()
 
-  def __getitem__(self, name: str) -> tp.Hashable:
+  def __getitem__(self, name: str) -> tp.Any:
     current_flags = FLAGS_CONTEXT.flags_stack[-1]
     if name not in current_flags:
       raise ValueError(f'Unknown Flag: {name}')
@@ -50,7 +50,7 @@ def __contains__(self, name: tp.Any) -> bool:
     return name in FLAGS_CONTEXT.flags_stack[-1]
 
   @contextmanager
-  def __call__(self, **kwargs: tp.Hashable):
+  def __call__(self, **kwargs: tp.Any):
     current_flags = FLAGS_CONTEXT.flags_stack[-1]
     FLAGS_CONTEXT.flags_stack.append(
       MappingProxyType(dict(current_flags, **kwargs))
@@ -60,9 +60,7 @@ def __call__(self, **kwargs: tp.Hashable):
     finally:
       FLAGS_CONTEXT.flags_stack.pop()
 
-  def get(
-    self, name: str, default: tp.Hashable = None
-  ) -> tp.Optional[tp.Hashable]:
+  def get(self, name: str, default: tp.Any = None) -> tp.Optional[tp.Any]:
     return FLAGS_CONTEXT.flags_stack[-1].get(name, default)
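The tp.Hashable to tp.Any loosening reflects that flag values are only stored and looked up by name, never used as dict keys, so non-hashable values such as lists are legal. A self-contained sketch of the thread-local flag-stack pattern this file implements; the names (_FlagsState, flags, get_flag) are illustrative, not the module's API:

import threading
import typing as tp
from contextlib import contextmanager
from types import MappingProxyType

class _FlagsState(threading.local):
  def __init__(self):
    # Each thread starts with one empty, immutable mapping on its stack.
    self.stack: tp.List[tp.Mapping[str, tp.Any]] = [MappingProxyType({})]

_STATE = _FlagsState()

@contextmanager
def flags(**kwargs: tp.Any):
  # Push an immutable merged view; always pop, even if the body raises.
  _STATE.stack.append(MappingProxyType(dict(_STATE.stack[-1], **kwargs)))
  try:
    yield
  finally:
    _STATE.stack.pop()

def get_flag(name: str, default: tp.Any = None) -> tp.Any:
  return _STATE.stack[-1].get(name, default)

with flags(deterministic=True, axes=['batch']):  # a list: fine under tp.Any
  assert get_flag('deterministic') is True
assert get_flag('deterministic') is None  # stack popped on exit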
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/helpers.py
@@ -80,9 +80,9 @@ def __len__(self) -> int:
 
 
 class Sequence(Module, tp.Generic[A]):
-  def __init__(self, iterable: tp.Iterable[A]):
+  def __init__(self, layers: tp.Iterable[A]):
     i = 0
-    for i, value in enumerate(iterable):
+    for i, value in enumerate(layers):
       setattr(self, str(i), value)
     self._length = i + 1
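The parameter rename is cosmetic, but the constructor pattern deserves a note: each element is bound to a numbered attribute via setattr, which is what lets the parent Module machinery discover the layers as submodules. A plain-Python sketch of the same pattern, without the nnx dependency:

class Sequence:  # sketch only; the real class derives from nnx Module
  def __init__(self, layers):
    i = 0
    for i, value in enumerate(layers):
      # Numbered attributes are what the Module machinery discovers.
      setattr(self, str(i), value)
    self._length = i + 1  # note: an empty `layers` still yields length 1

  def __call__(self, x):
    for i in range(self._length):
      x = getattr(self, str(i))(x)
    return x

seq = Sequence([lambda x: x + 1, lambda x: x * 2])
assert seq(3) == 8  # (3 + 1) * 2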
7 changes: 2 additions & 5 deletions flax/experimental/nnx/nnx/module.py
@@ -519,15 +519,12 @@ def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]):
   vars(node).update(items)
 
 
-# -------------------------
-# utils
-# -------------------------
-def first_from(*args: tp.Optional[A]) -> A:
+def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
   """Return the first non-None argument."""
   for arg in args:
     if arg is not None:
       return arg
-  raise ValueError('No non-None arguments found.')
+  raise ValueError(f'No non-None arguments found for {arg_name!r}')
 
 
 def merge(
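The added arg_name parameter only enriches the error message, but it documents first_from's role: resolving a call-time override against an init-time default. A self-contained sketch of that usage; the Dropout-style class below is illustrative, not the nnx implementation:

import typing as tp

A = tp.TypeVar('A')

def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
  """Return the first non-None argument (as in the diff above)."""
  for arg in args:
    if arg is not None:
      return arg
  raise ValueError(f'No non-None arguments found for {arg_name!r}')

class Dropout:  # illustrative stand-in, not the nnx class
  def __init__(self, deterministic: tp.Optional[bool] = None):
    self.deterministic = deterministic

  def __call__(self, x, *, deterministic: tp.Optional[bool] = None):
    # The call-time argument wins; otherwise fall back to the attribute.
    deterministic = first_from(
      'deterministic', deterministic, self.deterministic
    )
    return x if deterministic else x * 0.5  # real masking elided

assert Dropout(deterministic=True)(2.0) == 2.0
assert Dropout()(2.0, deterministic=False) == 1.0
try:
  Dropout()(2.0)  # neither source is set: the error now names the argument
except ValueError as e:
  assert 'deterministic' in str(e)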