diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py
index 47af5792a6..187252e30e 100644
--- a/flax/experimental/nnx/nnx/module.py
+++ b/flax/experimental/nnx/nnx/module.py
@@ -575,12 +575,21 @@ def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]):
   vars(node).update(items)
 
 
-def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
-  """Return the first non-None argument."""
+def first_from(*args: tp.Optional[A], error_msg: str) -> A:
+  """Return the first non-None argument.
+
+  If all arguments are None, raise a ValueError with the given error message.
+
+  Args:
+    *args: the arguments to check
+    error_msg: the error message to raise if all arguments are None
+  Returns:
+    The first non-None argument.
+  """
   for arg in args:
     if arg is not None:
       return arg
-  raise ValueError(f'No non-None arguments found for {arg_name!r}')
+  raise ValueError(error_msg)
 
 
 def merge(
diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py
index 56135a8a90..70a63bd5a0 100644
--- a/flax/experimental/nnx/nnx/nn/attention.py
+++ b/flax/experimental/nnx/nnx/nn/attention.py
@@ -25,7 +25,7 @@
 
 from flax.experimental import nnx
 from flax.experimental.nnx.nnx import rnglib
-from flax.experimental.nnx.nnx.flaglib import flags
+from flax.experimental.nnx.nnx import flaglib
 from flax.experimental.nnx.nnx.module import Module, first_from
 from flax.experimental.nnx.nnx.nn import initializers
 from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype
@@ -305,7 +305,7 @@ def __init__(
     bias_init: initializers.Initializer = initializers.zeros_init(),
     use_bias: bool = True,
     attention_fn: Callable[..., Array] = dot_product_attention,
-    decode: bool = False,
+    decode: bool | None = None,
     normalize_qk: bool = False,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
@@ -402,6 +402,7 @@ def __call__(
     dropout_rng: Optional[Array] = None,
     rngs: rnglib.Rngs | None = None,
     sow_weights: bool = False,
+    decode: bool | None = None,
   ):
     ...
 
@@ -416,6 +417,7 @@ def __call__(
     dropout_rng: Array | None = None,
     rngs: rnglib.Rngs | None = None,
     sow_weights: bool = False,
+    decode: bool | None = None,
   ):
     ...
 
@@ -429,6 +431,7 @@ def __call__(
     deterministic: bool | None = None,
     rngs: rnglib.Rngs | None = None,
     sow_weights: bool = False,
+    decode: bool | None = None,
   ):
     """Applies multi-head dot product attention on the input data.
 
@@ -490,7 +493,15 @@ def __call__(
 
     # During fast autoregressive decoding, we feed one position at a time,
     # and cache the keys and values step by step.
-    if self.decode:
+    decode = first_from(
+      decode,
+      self.decode,
+      flaglib.flags.get('decode'),
+      error_msg="""No `decode` argument was provided to MultiHeadAttention
+        as either a __call__ argument, class attribute, or nnx.flag.""",
+    )
+
+    if decode:
       (
         *batch_dims,
         max_length,
@@ -530,10 +541,11 @@ def __call__(
       self.dropout_rate > 0.0
     ):  # Require `deterministic` only if using dropout.
       deterministic = first_from(
-        'deterministic',
         deterministic,
         self.deterministic,
-        flags.get('deterministic'),
+        flaglib.flags.get('deterministic'),
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
       )
       if not deterministic:
         if rngs is None:
diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py
index f8a42e78ec..7bb84a5008 100644
--- a/flax/experimental/nnx/nnx/nn/normalization.py
+++ b/flax/experimental/nnx/nnx/nn/normalization.py
@@ -258,10 +258,11 @@ def __call__(
     """
 
     use_running_average = first_from(
-      'use_running_average',
       use_running_average,
       self.use_running_average,
       flaglib.flags.get('use_running_average'),
+      error_msg="""No `use_running_average` argument was provided to BatchNorm
+        as either a __call__ argument, class attribute, or nnx.flag.""",
     )
     feature_axes = _canonicalize_axes(x.ndim, self.axis)
     reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py
index bc114ce359..4bc59f671c 100644
--- a/flax/experimental/nnx/nnx/nn/stochastic.py
+++ b/flax/experimental/nnx/nnx/nn/stochastic.py
@@ -59,10 +59,11 @@ def __call__(
       The masked inputs reweighted to preserve mean.
     """
     deterministic = first_from(
-      'deterministic',
       deterministic,
       self.deterministic,
       flaglib.flags.get('deterministic'),
+      error_msg="""No `deterministic` argument was provided to Dropout
+        as either a __call__ argument, class attribute, or nnx.flag.""",
     )
 
     if (self.rate == 0.0) or deterministic:
diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/experimental/nnx/tests/nn/test_attention.py
index d85b00cfc2..e2ab44b430 100644
--- a/flax/experimental/nnx/tests/nn/test_attention.py
+++ b/flax/experimental/nnx/tests/nn/test_attention.py
@@ -26,7 +26,7 @@ def test_basic(self):
       out_features=6,
       rngs=nnx.Rngs(0),
     )
-    y = module(jnp.ones((1, 7, 3)))
+    y = module(jnp.ones((1, 7, 3)), decode=False)
     assert y.shape == (1, 7, 6)
 
   def test_multihead_sow_attention_weights(self):
@@ -58,7 +58,8 @@ def __call__(self, x, sow_weights=False):
       rng,
     )
 
-    _ = module(x, True)
+    with nnx.flags(decode=False):
+      _ = module(x, True)
     intermediates = module.pop(nnx.Intermediate)
     assert intermediates['attention_layers/0/attention_weights'][0].shape == (
       4,
@@ -74,7 +75,8 @@ def __call__(self, x, sow_weights=False):
       6,
     )
 
-    _ = module(x)
+    with nnx.flags(decode=False):
+      _ = module(x)
     intermediates = module.pop(nnx.Intermediate)
     assert not intermediates  # empty
 
@@ -86,7 +88,7 @@ def test_autoregressive_decode_with_x64(self):
       num_heads=2,
       qkv_features=4,
       decode=True,
-      rngs=nnx.Rngs(0)
+      rngs=nnx.Rngs(0),
    )
    module.init_cache(x.shape, dtype=x.dtype)
    assert module.cached_key.shape == (1, 4, 2, 2)
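
Usage sketch (illustrative, not part of the diff): after this change, `decode` and `deterministic` are resolved by `first_from` in priority order, the `__call__` argument first, then the module attribute set at construction, then the `nnx.flags` context, and a `ValueError` carrying the supplied `error_msg` is raised when all three are `None`. The snippet below assumes hypothetical constructor values (`in_features=3`, `num_heads=2`, etc.) chosen only to match the shapes used in the tests; the `decode=...` plumbing is what the diff actually adds.

```python
import jax.numpy as jnp
from flax.experimental import nnx
from flax.experimental.nnx.nnx.module import first_from

# first_from returns the first non-None value, or raises ValueError(error_msg).
assert first_from(None, True, None, error_msg='no value provided') is True

# Hypothetical layer; constructor values here are illustrative, not from the diff.
module = nnx.MultiHeadAttention(
  num_heads=2,
  in_features=3,
  qkv_features=4,
  out_features=6,
  rngs=nnx.Rngs(0),
)
x = jnp.ones((1, 7, 3))

# 1) Highest priority: pass decode directly to __call__.
y = module(x, decode=False)

# 2) Next: the class attribute set at construction time (decode=True would
#    enable the autoregressive cache, as in test_autoregressive_decode_with_x64).

# 3) Lowest priority: scope it with the nnx.flags context manager.
with nnx.flags(decode=False):
  y = module(x)

# With no __call__ argument, no attribute, and no flag, the layer raises
# ValueError with the error_msg string shown in the diff.
```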