From 1d41190b98f300e409edc054e9f6fdd2d398d443 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Thu, 28 Sep 2023 15:22:32 -0700 Subject: [PATCH] split inputs_kv arg in attention layer --- examples/wmt/models.py | 2 +- flax/linen/attention.py | 90 +++++++++++++++++++++++++-- pyproject.toml | 6 ++ tests/linen/linen_attention_test.py | 95 +++++++++++++++++++++++++---- 4 files changed, 173 insertions(+), 20 deletions(-) diff --git a/examples/wmt/models.py b/examples/wmt/models.py index 0f2fd2f962..6ed08ccd23 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -299,7 +299,7 @@ def __call__( broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, - )(y, encoded, encoder_decoder_mask) + )(y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=config.dropout_rate)( y, deterministic=config.deterministic diff --git a/flax/linen/attention.py b/flax/linen/attention.py index 851434f15a..575620efcd 100644 --- a/flax/linen/attention.py +++ b/flax/linen/attention.py @@ -15,7 +15,8 @@ """Attention core modules for Flax.""" import functools -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union, overload +import warnings from flax.linen import initializers from flax.linen.dtypes import promote_dtype @@ -248,11 +249,37 @@ class MultiHeadDotProductAttention(Module): qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None + @overload + def __call__( + self, + inputs_q: Array, + inputs_k: Optional[Array] = None, + inputs_v: Optional[Array] = None, + *, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): + ... + + @overload + def __call__( + self, + inputs_q: Array, + *, + inputs_kv: Array = None, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): + ... + @compact def __call__( self, inputs_q: Array, - inputs_kv: Array, + inputs_k: Optional[Array] = None, + inputs_v: Optional[Array] = None, + *, + inputs_kv: Optional[Array] = None, mask: Optional[Array] = None, deterministic: Optional[bool] = None, ): @@ -261,9 +288,19 @@ def __call__( Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. + If both inputs_k and inputs_v are None, they will both copy the value of + inputs_q (self attention). + If only inputs_v is None, it will copy the value of inputs_k. + Args: inputs_q: input queries of shape `[batch_sizes..., length, features]`. - inputs_kv: key/values of shape `[batch_sizes..., length, features]`. + inputs_k: key of shape `[batch_sizes..., length, features]`. If None, + inputs_k will copy the value of inputs_q. + inputs_v: values of shape `[batch_sizes..., length, features]`. If None, + inputs_v will copy the value of inputs_k. + inputs_kv: key/values of shape `[batch_sizes..., length, features]`. If + None, inputs_kv will copy the value of inputs_q. This arg will be + deprecated soon. Use inputs_k and inputs_v instead. mask: attention mask of shape `[batch_sizes..., num_heads, query_length, key/value_length]`. Attention weights are masked out if their corresponding mask value is `False`. @@ -273,6 +310,42 @@ def __call__( Returns: output of shape `[batch_sizes..., length, features]`. """ + if inputs_kv is not None: + if inputs_k is not None or inputs_v is not None: + raise ValueError('If either `inputs_k` or `inputs_v` is not None, ' + '`inputs_kv` must be None. 
If `inputs_kv` is not None, both `inputs_k` ' + 'and `inputs_v` must be None. We recommend using `inputs_k` and ' + '`inputs_v` args, since `inputs_kv` will be deprecated soon. See ' + 'https://github.com/google/flax/discussions/3389 for more ' + 'information.') + inputs_k = inputs_v = inputs_kv + warnings.warn('The inputs_kv arg will be deprecated soon. ' + 'Use inputs_k and inputs_v instead. See ' + 'https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning) + else: + if inputs_k is None: + if inputs_v is not None: + raise ValueError('`inputs_k` cannot be None if `inputs_v` is not None. ' + 'To have both `inputs_k` and `inputs_v` be the same value, pass in the ' + 'value to `inputs_k` and leave `inputs_v` as None.') + inputs_k = inputs_q + if inputs_v is None: + inputs_v = inputs_k + elif inputs_v.shape[-1] == inputs_v.shape[-2]: + warnings.warn(f"You are passing an array of shape {inputs_v.shape} " + "to the `inputs_v` arg, when you may have intended " + "to pass it to the `mask` arg. As of Flax version " + "0.7.4, the function signature of " + "MultiHeadDotProductAttention's `__call__` method " + "has changed to `__call__(inputs_q, inputs_k=None, " + "inputs_v=None, *, inputs_kv=None, mask=None, " + "deterministic=None)`. Use the kwarg `mask` instead. " + "See https://github.com/google/flax/discussions/3389 " + "and read the docstring for more information.", + DeprecationWarning) + features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( @@ -298,8 +371,8 @@ def __call__( # dimensions are then [batch..., length, n_heads, n_features_per_head] query, key, value = ( dense(name='query')(inputs_q), - dense(name='key')(inputs_kv), - dense(name='value')(inputs_kv), + dense(name='key')(inputs_k), + dense(name='value')(inputs_v), ) if self.normalize_qk: @@ -429,8 +502,13 @@ def __call__( # type: ignore Returns: output of shape `[batch_sizes..., length, features]`. """ + warnings.warn('SelfAttention will be deprecated soon. Use ' + '`MultiHeadDotProductAttention.__call__(inputs_q)` instead. ' + 'See https://github.com/google/flax/discussions/3389 ' + 'for more information.', + DeprecationWarning) return super().__call__( - inputs_q, inputs_q, mask, deterministic=deterministic + inputs_q, mask=mask, deterministic=deterministic ) diff --git a/pyproject.toml b/pyproject.toml index e65d5e8677..ea79d87cd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,12 @@ filterwarnings = [ "ignore:.*module 'sre_constants' is deprecated.*:DeprecationWarning", # DeprecationWarning: jax.random.KeyArray is deprecated. "ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning", + # DeprecationWarning: SelfAttention will be deprecated soon. + "ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning", + # DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead. + "ignore:.*The inputs_kv arg will be deprecated soon. 
Use inputs_k and inputs_v instead.*:DeprecationWarning", + # DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed + "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning" ] [tool.coverage.report] diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 2556b292c2..aba47516e4 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -17,6 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized +from flax import errors from flax import linen as nn from flax import jax_utils from flax.core import pop @@ -67,7 +68,6 @@ def test_dtype_infer(self): def test_multihead_encoder_decoder_attention(self): rng = random.key(0) q = jnp.ones((4, 2, 3, 5)) - kv = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, @@ -75,7 +75,7 @@ def test_multihead_encoder_decoder_attention(self): bias_init=initializers.zeros, deterministic=False, ) - y, _ = sa_module.init_with_output(rng, q, kv) + y, _ = sa_module.init_with_output(rng, q) self.assertEqual(y.shape, q.shape) def test_multihead_self_attention_w_dropout(self): @@ -91,7 +91,7 @@ def test_multihead_self_attention_w_dropout(self): ) rng1, rng2 = random.split(rng) rngs = {'params': rng1, 'dropout': rng2} - y, _ = sa_module.init_with_output(rngs, x, x) + y, _ = sa_module.init_with_output(rngs, x) self.assertEqual(y.shape, x.shape) def test_multihead_self_attention_w_dropout_disabled(self): @@ -108,11 +108,11 @@ def test_multihead_self_attention_w_dropout_disabled(self): rng1, rng2, rng3, rng4 = random.split(rng, 4) rngs1 = {'params': rng1, 'dropout': rng2} rngs2 = {'params': rng3, 'dropout': rng4} - y1, vs = sa_module0.init_with_output(rngs1, x, x) - y2, _ = sa_module0.init_with_output(rngs2, x, x) + y1, vs = sa_module0.init_with_output(rngs1, x) + y2, _ = sa_module0.init_with_output(rngs2, x) np.testing.assert_allclose(y1, y2) - y3 = sa_module0.apply(vs, x, x, rngs=rngs1) - y4 = sa_module0.apply(vs, x, x, rngs=rngs2) + y3 = sa_module0.apply(vs, x, rngs=rngs1) + y4 = sa_module0.apply(vs, x, rngs=rngs2) np.testing.assert_allclose(y3, y4) sa_module1 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -121,8 +121,8 @@ def test_multihead_self_attention_w_dropout_disabled(self): bias_init=initializers.zeros, dropout_rate=0.0, ) - y5 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs1) - y6 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs2) + y5 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs1) + y6 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y5, y6) sa_module2 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -131,8 +131,8 @@ def test_multihead_self_attention_w_dropout_disabled(self): bias_init=initializers.zeros, dropout_rate=0.5, ) - y7 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs1) - y8 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs2) + y7 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs1) + y8 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y7, y8) def test_causal_mask_1d(self): @@ -204,11 +204,11 @@ def test_autoregresive_receptive_field_1d(self): deterministic=False, ) - initial_vars = module.init(rng1, inputs, inputs) + initial_vars = module.init(rng1, inputs) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) def model_loss(inputs, pos): - out = 
module.apply(initial_vars, inputs, inputs, causal_mask) + out = module.apply(initial_vars, inputs, mask=causal_mask) assert out.shape == input_shape assert len(out.shape) == 3 return out[0, pos, :].sum() @@ -234,6 +234,75 @@ def get_receptive_field_1d(pos): 'autoregressive self-attention.' ) + def test_multihead_self_attention_equality(self): + rng = random.key(0) + q = jnp.ones((4, 2, 3, 5)) + module_kwargs = {'num_heads': 8, + 'qkv_features': 16, + 'kernel_init': initializers.ones, + 'bias_init': initializers.zeros, + 'deterministic': False} + sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs) + sa_module1 = nn.SelfAttention(**module_kwargs) + y0, v0 = sa_module0.init_with_output(rng, q) + with self.assertWarnsRegex(DeprecationWarning, 'SelfAttention will be deprecated soon.'): + y1, v1 = sa_module1.init_with_output(rng, q) + self.assertTrue((y0 == y1).all()) + self.assertTrue(jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1))) + + def test_multihead_kv_args(self): + key1, key2 = random.split(random.key(0), 2) + query = random.uniform(key1, (3, 5)) + key_value = random.uniform(key1, (9, 5)) + module = nn.MultiHeadDotProductAttention( + num_heads=8, + qkv_features=16, + kernel_init=initializers.ones, + bias_init=initializers.zeros, + deterministic=False, + ) + y0, v0 = module.init_with_output(key2, query, inputs_k=key_value, inputs_v=key_value) + y1, v1 = module.init_with_output(key2, query, inputs_k=key_value) + with self.assertWarnsRegex(DeprecationWarning, 'The inputs_kv arg will be deprecated soon.'): + y2, v2 = module.init_with_output(key2, query, inputs_kv=key_value) + self.assertTrue((y0 == y1).all() and (y1 == y2).all()) + self.assertTrue( + jax.tree_util.tree_all( + jax.tree_map(lambda x, y, z: (x == y).all() and (y == z).all(), + v0, v1, v2))) + + with self.assertRaisesRegex(ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.'): + y3, v3 = module.init_with_output(key2, query, inputs_v=key_value) + with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'): + y3, v3 = module.init_with_output(key2, query, inputs_kv=key_value, inputs_v=key_value) + with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'): + y3, v3 = module.init_with_output(key2, query, key_value, key_value, inputs_kv=key_value) + + def test_multihead_mask_warning(self): + rng = random.key(0) + rng1, rng2 = random.split(rng, num=2) + + length = 10 + dim = 1 + num_heads = 1 + input_shape = (1, length, dim) + query = key = random.normal(rng2, input_shape) + + module = nn.MultiHeadDotProductAttention( + num_heads=num_heads, + kernel_init=jax.nn.initializers.ones, + deterministic=False, + ) + + initial_vars = module.init(rng1, query, key) + causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) + + module.apply(initial_vars, query, key, mask=causal_mask) + with self.assertWarnsRegex(DeprecationWarning, + "the function signature of MultiHeadDotProductAttention's `__call__` method has changed"): + with self.assertRaises(errors.ScopeParamShapeError): + module.apply(initial_vars, query, key, causal_mask) + if __name__ == '__main__': absltest.main()
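
Usage sketch (not part of the patch): the example below illustrates how the call signature introduced by this change behaves. It assumes a Flax release that includes this patch (0.7.4 or later); the module configuration, array shapes, and variable names are arbitrary values chosen for demonstration only.

import jax
import jax.numpy as jnp
from flax import linen as nn

rng = jax.random.key(0)
q = jnp.ones((2, 4, 16))   # queries: [batch, q_length, features]
kv = jnp.ones((2, 7, 16))  # keys/values: [batch, kv_length, features]

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=16)

# Self-attention: inputs_k and inputs_v default to inputs_q.
y_self, variables = attn.init_with_output(rng, q)

# Cross-attention: pass keys (and optionally values) positionally or by keyword.
# Leaving inputs_v as None makes it copy inputs_k.
y_cross = attn.apply(variables, q, kv)       # inputs_k=kv, inputs_v=kv
y_cross2 = attn.apply(variables, q, kv, kv)  # same result, inputs_v given explicitly

# Deprecated path: inputs_kv still works but now emits a DeprecationWarning.
y_old = attn.apply(variables, q, inputs_kv=kv)

# mask is keyword-only in the new signature. Old code that passed the mask as the
# third positional argument would now feed it into inputs_v, which the shape check
# added in this patch tries to catch with a warning.
mask = jnp.ones((2, 4, 4, 7), dtype=bool)    # [batch, num_heads, q_length, kv_length]
y_masked = attn.apply(variables, q, kv, mask=mask)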