
[nnx] improve runtime flags #3607

Merged 1 commit on Jan 9, 2024
15 changes: 12 additions & 3 deletions flax/experimental/nnx/nnx/module.py
@@ -575,12 +575,21 @@ def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]):
   vars(node).update(items)


-def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
-  """Return the first non-None argument."""
+def first_from(*args: tp.Optional[A], error_msg: str) -> A:
+  """Return the first non-None argument.
+
+  If all arguments are None, raise a ValueError with the given error message.
+
+  Args:
+    *args: the arguments to check
+    error_msg: the error message to raise if all arguments are None
+  Returns:
+    The first non-None argument.
+  """
   for arg in args:
     if arg is not None:
       return arg
-  raise ValueError(f'No non-None arguments found for {arg_name!r}')
+  raise ValueError(error_msg)


 def merge(
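For reference, the updated helper is small enough to exercise on its own. The sketch below mirrors the new signature from the diff; the surrounding variable names are illustrative only and are not part of the PR.

import typing as tp

A = tp.TypeVar('A')

def first_from(*args: tp.Optional[A], error_msg: str) -> A:
  # Return the first non-None value; the caller supplies the error text.
  for arg in args:
    if arg is not None:
      return arg
  raise ValueError(error_msg)

# Resolution order used throughout this PR: call argument, then class
# attribute, then global flag.
call_arg, attribute, global_flag = None, None, False
decode = first_from(
  call_arg,
  attribute,
  global_flag,
  error_msg='No `decode` value was provided.',
)
assert decode is False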
22 changes: 17 additions & 5 deletions flax/experimental/nnx/nnx/nn/attention.py
@@ -25,7 +25,7 @@

 from flax.experimental import nnx
 from flax.experimental.nnx.nnx import rnglib
-from flax.experimental.nnx.nnx.flaglib import flags
+from flax.experimental.nnx.nnx import flaglib
 from flax.experimental.nnx.nnx.module import Module, first_from
 from flax.experimental.nnx.nnx.nn import initializers
 from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype
@@ -305,7 +305,7 @@ def __init__(
     bias_init: initializers.Initializer = initializers.zeros_init(),
     use_bias: bool = True,
     attention_fn: Callable[..., Array] = dot_product_attention,
-    decode: bool = False,
+    decode: bool | None = None,
     normalize_qk: bool = False,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
@@ -402,6 +402,7 @@ def __call__(
     dropout_rng: Optional[Array] = None,
     rngs: rnglib.Rngs | None = None,
     sow_weights: bool = False,
+    decode: bool | None = None,
   ):
     ...

@@ -416,6 +417,7 @@ def __call__(
     dropout_rng: Array | None = None,
     rngs: rnglib.Rngs | None = None,
     sow_weights: bool = False,
+    decode: bool | None = None,
   ):
     ...

@@ -429,6 +431,7 @@ def __call__(
     deterministic: bool | None = None,
     rngs: rnglib.Rngs | None = None,
     sow_weights: bool = False,
+    decode: bool | None = None,
   ):
     """Applies multi-head dot product attention on the input data.

@@ -490,7 +493,15 @@ def __call__(

     # During fast autoregressive decoding, we feed one position at a time,
     # and cache the keys and values step by step.
-    if self.decode:
+    decode = first_from(
+      decode,
+      self.decode,
+      flaglib.flags.get('decode'),
+      error_msg="""No `decode` argument was provided to MultiHeadAttention
+        as either a __call__ argument, class attribute, or nnx.flag.""",
+    )
+
+    if decode:
       (
         *batch_dims,
         max_length,
@@ -530,10 +541,11 @@ def __call__(
       self.dropout_rate > 0.0
     ):  # Require `deterministic` only if using dropout.
       deterministic = first_from(
-        'deterministic',
         deterministic,
         self.deterministic,
-        flags.get('deterministic'),
+        flaglib.flags.get('deterministic'),
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
       )
       if not deterministic:
         if rngs is None:
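Taken together, the attention changes let `decode` be resolved at call time from three sources in priority order. A hedged usage sketch, assuming a constructor along the lines of the tests in this PR (the `in_features`, `qkv_features`, and `out_features` values here are illustrative assumptions):

import jax.numpy as jnp
from flax.experimental import nnx

# Constructor arguments below are assumed for illustration.
module = nnx.MultiHeadAttention(
  num_heads=2,
  in_features=3,
  qkv_features=6,
  out_features=6,
  rngs=nnx.Rngs(0),
)
x = jnp.ones((1, 7, 3))

# 1) An explicit __call__ argument wins over everything else.
y = module(x, decode=False)

# 2) Otherwise the class attribute (decode=... at construction) is used.
# 3) Otherwise the global nnx flag is consulted, as in the updated tests.
with nnx.flags(decode=False):
  y = module(x)

# If none of the three sources is set, the ValueError carrying the
# error_msg shown in the diff is raised.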
3 changes: 2 additions & 1 deletion flax/experimental/nnx/nnx/nn/normalization.py
@@ -258,10 +258,11 @@ def __call__(
"""

use_running_average = first_from(
'use_running_average',
use_running_average,
self.use_running_average,
flaglib.flags.get('use_running_average'),
error_msg="""No `use_running_average` argument was provided to BatchNorm
as either a __call__ argument, class attribute, or nnx.flag.""",
)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
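BatchNorm follows the same chain for `use_running_average`. A minimal sketch, assuming the `num_features` constructor argument and the input shape shown here (both are illustrative, not taken from this PR):

import jax.numpy as jnp
from flax.experimental import nnx

bn = nnx.BatchNorm(num_features=3, rngs=nnx.Rngs(0))
x = jnp.ones((4, 3))

y = bn(x, use_running_average=True)        # explicit call argument
with nnx.flags(use_running_average=True):  # or a global flag
  y = bn(x)
# With no call argument, attribute, or flag set, the new error message
# above is raised instead.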
3 changes: 2 additions & 1 deletion flax/experimental/nnx/nnx/nn/stochastic.py
@@ -59,10 +59,11 @@ def __call__(
       The masked inputs reweighted to preserve mean.
     """
     deterministic = first_from(
-      'deterministic',
       deterministic,
       self.deterministic,
       flaglib.flags.get('deterministic'),
+      error_msg="""No `deterministic` argument was provided to Dropout
+        as either a __call__ argument, class attribute, or nnx.flag.""",
     )

     if (self.rate == 0.0) or deterministic:
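Dropout's `deterministic` handling mirrors the same resolution chain. A small sketch (the constructor and call signatures here are assumptions for illustration):

import jax.numpy as jnp
from flax.experimental import nnx

dropout = nnx.Dropout(rate=0.5)
x = jnp.ones((2, 3))

with nnx.flags(deterministic=True):
  y = dropout(x)  # deterministic: dropout is a no-op and returns its input
assert (y == x).all()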
10 changes: 6 additions & 4 deletions flax/experimental/nnx/tests/nn/test_attention.py
@@ -26,7 +26,7 @@ def test_basic(self):
       out_features=6,
       rngs=nnx.Rngs(0),
     )
-    y = module(jnp.ones((1, 7, 3)))
+    y = module(jnp.ones((1, 7, 3)), decode=False)
Collaborator review comment:
Might be worth adding another test that instantiates MultiHeadAttention with decode=False as an init arg. It's already exercised in the autoregressive test, but the main point of that test is jax.experimental.enable_x64.

     assert y.shape == (1, 7, 6)

   def test_multihead_sow_attention_weights(self):
@@ -58,7 +58,8 @@ def __call__(self, x, sow_weights=False):
       rng,
     )

-    _ = module(x, True)
+    with nnx.flags(decode=False):
+      _ = module(x, True)
     intermediates = module.pop(nnx.Intermediate)
     assert intermediates['attention_layers/0/attention_weights'][0].shape == (
       4,
@@ -74,7 +75,8 @@ def __call__(self, x, sow_weights=False):
       6,
     )

-    _ = module(x)
+    with nnx.flags(decode=False):
+      _ = module(x)
     intermediates = module.pop(nnx.Intermediate)
     assert not intermediates  # empty

@@ -86,7 +88,7 @@ def test_autoregressive_decode_with_x64(self):
       num_heads=2,
       qkv_features=4,
       decode=True,
-      rngs=nnx.Rngs(0)
+      rngs=nnx.Rngs(0),
     )
     module.init_cache(x.shape, dtype=x.dtype)
     assert module.cached_key.shape == (1, 4, 2, 2)