[nnx] improve runtime flags
cgarciae committed Jan 8, 2024
1 parent 7cbc3c1 commit 5546627
Showing 5 changed files with 39 additions and 14 deletions.
15 changes: 12 additions & 3 deletions flax/experimental/nnx/nnx/module.py
@@ -575,12 +575,21 @@ def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]):
vars(node).update(items)


def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
"""Return the first non-None argument."""
def first_from(*args: tp.Optional[A], error_msg: str) -> A:
"""Return the first non-None argument.
If all arguments are None, raise a ValueError with the given error message.
Args:
*args: the arguments to check
error_msg: the error message to raise if all arguments are None
Returns:
The first non-None argument.
"""
for arg in args:
if arg is not None:
return arg
raise ValueError(f'No non-None arguments found for {arg_name!r}')
raise ValueError(error_msg)


def merge(
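For reference, a minimal sketch of calling the updated first_from, which now takes a keyword-only error_msg instead of a leading arg_name (the values below are hypothetical, not taken from this diff):

from flax.experimental.nnx.nnx.module import first_from

# Returns 0.5: the first non-None value wins, scanning left to right.
rate = first_from(
  0.5,   # hypothetical call-site argument
  None,  # hypothetical class attribute
  None,  # hypothetical global flag
  error_msg='No `rate` argument was provided.',
)

# If every positional argument is None, ValueError(error_msg) is raised.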
22 changes: 17 additions & 5 deletions flax/experimental/nnx/nnx/nn/attention.py
@@ -25,7 +25,7 @@

from flax.experimental import nnx
from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.flaglib import flags
from flax.experimental.nnx.nnx import flaglib
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import initializers
from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype
@@ -305,7 +305,7 @@ def __init__(
bias_init: initializers.Initializer = initializers.zeros_init(),
use_bias: bool = True,
attention_fn: Callable[..., Array] = dot_product_attention,
decode: bool = False,
decode: bool | None = None,
normalize_qk: bool = False,
# Deprecated, will be removed.
qkv_dot_general: DotGeneralT | None = None,
@@ -402,6 +402,7 @@ def __call__(
dropout_rng: Optional[Array] = None,
rngs: rnglib.Rngs | None = None,
sow_weights: bool = False,
decode: bool | None = None,
):
...

@@ -416,6 +417,7 @@ def __call__(
dropout_rng: Array | None = None,
rngs: rnglib.Rngs | None = None,
sow_weights: bool = False,
decode: bool | None = None,
):
...

@@ -429,6 +431,7 @@ def __call__(
deterministic: bool | None = None,
rngs: rnglib.Rngs | None = None,
sow_weights: bool = False,
decode: bool | None = None,
):
"""Applies multi-head dot product attention on the input data.
@@ -490,7 +493,15 @@ def __call__(

# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.decode:
decode = first_from(
decode,
self.decode,
flaglib.flags.get('decode'),
error_msg="""No `decode` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)

if decode:
(
*batch_dims,
max_length,
@@ -530,10 +541,11 @@ def __call__(
self.dropout_rate > 0.0
): # Require `deterministic` only if using dropout.
deterministic = first_from(
'deterministic',
deterministic,
self.deterministic,
flags.get('deterministic'),
flaglib.flags.get('deterministic'),
error_msg="""No `deterministic` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)
if not deterministic:
if rngs is None:
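A brief usage sketch of the new decode resolution order in MultiHeadAttention (explicit call argument, then the decode attribute, then the nnx flag). The constructor arguments mirror those used in the tests below; shapes are illustrative only:

import jax.numpy as jnp
from flax.experimental import nnx

module = nnx.MultiHeadAttention(
  num_heads=2,
  in_features=3,
  qkv_features=6,
  out_features=6,
  rngs=nnx.Rngs(0),
)
x = jnp.ones((1, 7, 3))

# 1) An explicit call argument takes precedence.
y = module(x, decode=False)

# 2) Otherwise the value set via the nnx flags context manager is used.
with nnx.flags(decode=False):
  y = module(x)

# 3) If neither is given and self.decode is None, first_from raises
#    ValueError with the error_msg shown in the diff above.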
3 changes: 2 additions & 1 deletion flax/experimental/nnx/nnx/nn/normalization.py
@@ -258,10 +258,11 @@ def __call__(
"""

use_running_average = first_from(
'use_running_average',
use_running_average,
self.use_running_average,
flaglib.flags.get('use_running_average'),
error_msg="""No `use_running_average` argument was provided to BatchNorm
as either a __call__ argument, class attribute, or nnx.flag.""",
)
feature_axes = _canonicalize_axes(x.ndim, self.axis)
reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
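The same resolution pattern now carries an explicit error message for BatchNorm's use_running_average. A hedged sketch (the BatchNorm constructor arguments here are assumptions, not part of this diff; only the flag/attribute/argument lookup reflects the change):

import jax.numpy as jnp
from flax.experimental import nnx

# Assumed constructor: num_features first, rngs for parameter init.
bn = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
x = jnp.ones((4, 3))

with nnx.flags(use_running_average=True):
  y = bn(x)  # resolved from the nnx flag

y = bn(x, use_running_average=True)  # call argument takes precedence over the flag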
3 changes: 2 additions & 1 deletion flax/experimental/nnx/nnx/nn/stochastic.py
@@ -59,10 +59,11 @@ def __call__(
The masked inputs reweighted to preserve mean.
"""
deterministic = first_from(
'deterministic',
deterministic,
self.deterministic,
flaglib.flags.get('deterministic'),
error_msg="""No `deterministic` argument was provided to Dropout
as either a __call__ argument, class attribute, or nnx.flag.""",
)

if (self.rate == 0.0) or deterministic:
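Likewise for Dropout's deterministic. A hedged sketch (the Dropout constructor call is an assumption; only the deterministic resolution comes from this change):

import jax.numpy as jnp
from flax.experimental import nnx

dropout = nnx.Dropout(rate=0.5)  # assumed constructor
x = jnp.ones((4, 3))

with nnx.flags(deterministic=True):
  y = dropout(x)  # no masking: inputs pass through unchanged

y = dropout(x, deterministic=True)  # call argument takes precedence over the flag

# If deterministic is not supplied as a call argument, attribute, or flag,
# first_from raises ValueError with the error_msg shown in the diff above.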
10 changes: 6 additions & 4 deletions flax/experimental/nnx/tests/nn/test_attention.py
@@ -26,7 +26,7 @@ def test_basic(self):
out_features=6,
rngs=nnx.Rngs(0),
)
y = module(jnp.ones((1, 7, 3)))
y = module(jnp.ones((1, 7, 3)), decode=False)
assert y.shape == (1, 7, 6)

def test_multihead_sow_attention_weights(self):
@@ -58,7 +58,8 @@ def __call__(self, x, sow_weights=False):
rng,
)

_ = module(x, True)
with nnx.flags(decode=False):
_ = module(x, True)
intermediates = module.pop(nnx.Intermediate)
assert intermediates['attention_layers/0/attention_weights'][0].shape == (
4,
@@ -74,7 +75,8 @@ def __call__(self, x, sow_weights=False):
6,
)

_ = module(x)
with nnx.flags(decode=False):
_ = module(x)
intermediates = module.pop(nnx.Intermediate)
assert not intermediates # empty

@@ -86,7 +88,7 @@ def test_autoregressive_decode_with_x64(self):
num_heads=2,
qkv_features=4,
decode=True,
rngs=nnx.Rngs(0)
rngs=nnx.Rngs(0),
)
module.init_cache(x.shape, dtype=x.dtype)
assert module.cached_key.shape == (1, 4, 2, 2)
