Copybara import of the project:
--
f6a222c by Marcus Chiam <marcuschiam@google.com>:

split inputs_kv arg in attention layer

COPYBARA_INTEGRATE_REVIEW=#3379 from chiamp:attention f6a222c
PiperOrigin-RevId: 572671273
chiamp authored and Flax Authors committed Oct 11, 2023
1 parent 5a1fbc2 commit b76f487
Showing 4 changed files with 173 additions and 20 deletions.
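
For orientation, a minimal usage sketch of the call-signature change this commit introduces (illustrative only, not part of the diff; shapes, hyperparameters, and variable names are arbitrary assumptions):

import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
q = jnp.ones((4, 3, 5))   # [batch, q_length, features]
kv = jnp.ones((4, 7, 5))  # [batch, kv_length, features]

# Self-attention: a single positional argument now suffices (was `attn(x, x)`).
variables = attn.init(jax.random.key(0), q)
y_self = attn.apply(variables, q)

# Cross-attention: keys/values go to `inputs_k`/`inputs_v`; `mask` is keyword-only.
# The wmt example below changes `(y, encoded, encoder_decoder_mask)` to
# `(y, encoded, mask=encoder_decoder_mask)` accordingly.
variables = attn.init(jax.random.key(1), q, kv)
y_cross = attn.apply(variables, q, kv)  # inputs_v defaults to inputs_k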
2 changes: 1 addition & 1 deletion examples/wmt/models.py
@@ -299,7 +299,7 @@ def __call__(
broadcast_dropout=False,
dropout_rate=config.attention_dropout_rate,
deterministic=config.deterministic,
)(y, encoded, encoder_decoder_mask)
)(y, encoded, mask=encoder_decoder_mask)

y = nn.Dropout(rate=config.dropout_rate)(
y, deterministic=config.deterministic
90 changes: 84 additions & 6 deletions flax/linen/attention.py
@@ -15,7 +15,8 @@
"""Attention core modules for Flax."""

import functools
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union, overload
import warnings

from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
@@ -248,11 +249,37 @@ class MultiHeadDotProductAttention(Module):
qkv_dot_general_cls: Any = None
out_dot_general_cls: Any = None

@overload
def __call__(
self,
inputs_q: Array,
inputs_k: Optional[Array] = None,
inputs_v: Optional[Array] = None,
*,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
...

@overload
def __call__(
self,
inputs_q: Array,
*,
inputs_kv: Array = None,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
...

@compact
def __call__(
self,
inputs_q: Array,
inputs_kv: Array,
inputs_k: Optional[Array] = None,
inputs_v: Optional[Array] = None,
*,
inputs_kv: Optional[Array] = None,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
@@ -261,9 +288,19 @@ def __call__(
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
If both inputs_k and inputs_v are None, they will both copy the value of
inputs_q (self attention).
If only inputs_v is None, it will copy the value of inputs_k.
Args:
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
inputs_k will copy the value of inputs_q.
inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
inputs_v will copy the value of inputs_k.
inputs_kv: key/values of shape `[batch_sizes..., length, features]`. If
None, inputs_kv will copy the value of inputs_q. This arg will be
deprecated soon. Use inputs_k and inputs_v instead.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
@@ -273,6 +310,42 @@
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
if inputs_kv is not None:
if inputs_k is not None or inputs_v is not None:
raise ValueError('If either `inputs_k` or `inputs_v` is not None, '
'`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` '
'and `inputs_v` must be None. We recommend using `inputs_k` and '
'`inputs_v` args, since `inputs_kv` will be deprecated soon. See '
'https://github.com/google/flax/discussions/3389 for more '
'information.')
inputs_k = inputs_v = inputs_kv
warnings.warn('The inputs_kv arg will be deprecated soon. '
'Use inputs_k and inputs_v instead. See '
'https://github.com/google/flax/discussions/3389 '
'for more information.',
DeprecationWarning)
else:
if inputs_k is None:
if inputs_v is not None:
raise ValueError('`inputs_k` cannot be None if `inputs_v` is not None. '
'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
'value to `inputs_k` and leave `inputs_v` as None.')
inputs_k = inputs_q
if inputs_v is None:
inputs_v = inputs_k
elif inputs_v.shape[-1] == inputs_v.shape[-2]:
warnings.warn(f"You are passing an array of shape {inputs_v.shape} "
"to the `inputs_v` arg, when you may have intended "
"to pass it to the `mask` arg. As of Flax version "
"0.7.4, the function signature of "
"MultiHeadDotProductAttention's `__call__` method "
"has changed to `__call__(inputs_q, inputs_k=None, "
"inputs_v=None, *, inputs_kv=None, mask=None, "
"deterministic=None)`. Use the kwarg `mask` instead. "
"See https://github.com/google/flax/discussions/3389 "
"and read the docstring for more information.",
DeprecationWarning)

features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
assert qkv_features % self.num_heads == 0, (
@@ -298,8 +371,8 @@ def __call__(
# dimensions are then [batch..., length, n_heads, n_features_per_head]
query, key, value = (
dense(name='query')(inputs_q),
dense(name='key')(inputs_kv),
dense(name='value')(inputs_kv),
dense(name='key')(inputs_k),
dense(name='value')(inputs_v),
)

if self.normalize_qk:
@@ -429,8 +502,13 @@ def __call__( # type: ignore
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
warnings.warn('SelfAttention will be deprecated soon. Use '
'`MultiHeadDotProductAttention.__call__(inputs_q)` instead. '
'See https://github.com/google/flax/discussions/3389 '
'for more information.',
DeprecationWarning)
return super().__call__(
inputs_q, inputs_q, mask, deterministic=deterministic
inputs_q, mask=mask, deterministic=deterministic
)


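As a reading aid for the argument-resolution logic added above, a short hypothetical sketch (not part of the commit) of the three now-equivalent ways to supply keys/values; module configuration and shapes are assumptions:

import warnings
import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=8)
q = jnp.ones((2, 3, 8))
kv = jnp.ones((2, 6, 8))
variables = attn.init(jax.random.key(0), q, kv)

y0 = attn.apply(variables, q, inputs_k=kv, inputs_v=kv)  # explicit keys and values
y1 = attn.apply(variables, q, inputs_k=kv)               # inputs_v copies inputs_k
with warnings.catch_warnings():
    warnings.simplefilter('ignore', DeprecationWarning)
    y2 = attn.apply(variables, q, inputs_kv=kv)           # deprecated spelling
assert (y0 == y1).all() and (y1 == y2).all()

Passing `inputs_v` alone, or combining `inputs_kv` with `inputs_k`/`inputs_v`, raises a ValueError, and an `inputs_v` whose last two dimensions match (the typical shape of a mask passed positionally by mistake) triggers the shape warning defined above.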
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -147,6 +147,12 @@ filterwarnings = [
"ignore:.*module 'sre_constants' is deprecated.*:DeprecationWarning",
# DeprecationWarning: jax.random.KeyArray is deprecated.
"ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning",
# DeprecationWarning: SelfAttention will be deprecated soon.
"ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning",
# DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.
"ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning",
# DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed
"ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning"
# DeprecationWarning: ml_dtypes.float8_e4m3b11 is deprecated.
"ignore:.*ml_dtypes.float8_e4m3b11 is deprecated.*:DeprecationWarning",
]
95 changes: 82 additions & 13 deletions tests/linen/linen_attention_test.py
@@ -17,6 +17,7 @@
from absl.testing import absltest
from absl.testing import parameterized

from flax import errors
from flax import linen as nn
from flax import jax_utils
from flax.core import pop
@@ -67,15 +68,14 @@ def test_dtype_infer(self):
def test_multihead_encoder_decoder_attention(self):
rng = random.key(0)
q = jnp.ones((4, 2, 3, 5))
kv = jnp.ones((4, 2, 3, 5))
sa_module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
bias_init=initializers.zeros,
deterministic=False,
)
y, _ = sa_module.init_with_output(rng, q, kv)
y, _ = sa_module.init_with_output(rng, q)
self.assertEqual(y.shape, q.shape)

def test_multihead_self_attention_w_dropout(self):
Expand All @@ -91,7 +91,7 @@ def test_multihead_self_attention_w_dropout(self):
)
rng1, rng2 = random.split(rng)
rngs = {'params': rng1, 'dropout': rng2}
y, _ = sa_module.init_with_output(rngs, x, x)
y, _ = sa_module.init_with_output(rngs, x)
self.assertEqual(y.shape, x.shape)

def test_multihead_self_attention_w_dropout_disabled(self):
Expand All @@ -108,11 +108,11 @@ def test_multihead_self_attention_w_dropout_disabled(self):
rng1, rng2, rng3, rng4 = random.split(rng, 4)
rngs1 = {'params': rng1, 'dropout': rng2}
rngs2 = {'params': rng3, 'dropout': rng4}
y1, vs = sa_module0.init_with_output(rngs1, x, x)
y2, _ = sa_module0.init_with_output(rngs2, x, x)
y1, vs = sa_module0.init_with_output(rngs1, x)
y2, _ = sa_module0.init_with_output(rngs2, x)
np.testing.assert_allclose(y1, y2)
y3 = sa_module0.apply(vs, x, x, rngs=rngs1)
y4 = sa_module0.apply(vs, x, x, rngs=rngs2)
y3 = sa_module0.apply(vs, x, rngs=rngs1)
y4 = sa_module0.apply(vs, x, rngs=rngs2)
np.testing.assert_allclose(y3, y4)
sa_module1 = nn.MultiHeadDotProductAttention(
num_heads=8,
Expand All @@ -121,8 +121,8 @@ def test_multihead_self_attention_w_dropout_disabled(self):
bias_init=initializers.zeros,
dropout_rate=0.0,
)
y5 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs1)
y6 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs2)
y5 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs1)
y6 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs2)
np.testing.assert_allclose(y5, y6)
sa_module2 = nn.MultiHeadDotProductAttention(
num_heads=8,
Expand All @@ -131,8 +131,8 @@ def test_multihead_self_attention_w_dropout_disabled(self):
bias_init=initializers.zeros,
dropout_rate=0.5,
)
y7 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs1)
y8 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs2)
y7 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs1)
y8 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs2)
np.testing.assert_allclose(y7, y8)

def test_causal_mask_1d(self):
@@ -204,11 +204,11 @@ def test_autoregresive_receptive_field_1d(self):
deterministic=False,
)

initial_vars = module.init(rng1, inputs, inputs)
initial_vars = module.init(rng1, inputs)
causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1]))

def model_loss(inputs, pos):
out = module.apply(initial_vars, inputs, inputs, causal_mask)
out = module.apply(initial_vars, inputs, mask=causal_mask)
assert out.shape == input_shape
assert len(out.shape) == 3
return out[0, pos, :].sum()
@@ -234,6 +234,75 @@ def get_receptive_field_1d(pos):
'autoregressive self-attention.'
)

def test_multihead_self_attention_equality(self):
rng = random.key(0)
q = jnp.ones((4, 2, 3, 5))
module_kwargs = {'num_heads': 8,
'qkv_features': 16,
'kernel_init': initializers.ones,
'bias_init': initializers.zeros,
'deterministic': False}
sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs)
sa_module1 = nn.SelfAttention(**module_kwargs)
y0, v0 = sa_module0.init_with_output(rng, q)
with self.assertWarnsRegex(DeprecationWarning, 'SelfAttention will be deprecated soon.'):
y1, v1 = sa_module1.init_with_output(rng, q)
self.assertTrue((y0 == y1).all())
self.assertTrue(jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1)))

def test_multihead_kv_args(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
key_value = random.uniform(key1, (9, 5))
module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
bias_init=initializers.zeros,
deterministic=False,
)
y0, v0 = module.init_with_output(key2, query, inputs_k=key_value, inputs_v=key_value)
y1, v1 = module.init_with_output(key2, query, inputs_k=key_value)
with self.assertWarnsRegex(DeprecationWarning, 'The inputs_kv arg will be deprecated soon.'):
y2, v2 = module.init_with_output(key2, query, inputs_kv=key_value)
self.assertTrue((y0 == y1).all() and (y1 == y2).all())
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(lambda x, y, z: (x == y).all() and (y == z).all(),
v0, v1, v2)))

with self.assertRaisesRegex(ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.'):
y3, v3 = module.init_with_output(key2, query, inputs_v=key_value)
with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'):
y3, v3 = module.init_with_output(key2, query, inputs_kv=key_value, inputs_v=key_value)
with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'):
y3, v3 = module.init_with_output(key2, query, key_value, key_value, inputs_kv=key_value)

def test_multihead_mask_warning(self):
rng = random.key(0)
rng1, rng2 = random.split(rng, num=2)

length = 10
dim = 1
num_heads = 1
input_shape = (1, length, dim)
query = key = random.normal(rng2, input_shape)

module = nn.MultiHeadDotProductAttention(
num_heads=num_heads,
kernel_init=jax.nn.initializers.ones,
deterministic=False,
)

initial_vars = module.init(rng1, query, key)
causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1]))

module.apply(initial_vars, query, key, mask=causal_mask)
with self.assertWarnsRegex(DeprecationWarning,
"the function signature of MultiHeadDotProductAttention's `__call__` method has changed"):
with self.assertRaises(errors.ScopeParamShapeError):
module.apply(initial_vars, query, key, causal_mask)


if __name__ == '__main__':
absltest.main()
