Add inputs_k and inputs_v args to attention layer #3379

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion examples/wmt/models.py
@@ -299,7 +299,7 @@ def __call__(
broadcast_dropout=False,
dropout_rate=config.attention_dropout_rate,
deterministic=config.deterministic,
)(y, encoded, encoder_decoder_mask)
)(y, encoded, mask=encoder_decoder_mask)

y = nn.Dropout(rate=config.dropout_rate)(
y, deterministic=config.deterministic
90 changes: 84 additions & 6 deletions flax/linen/attention.py
@@ -15,7 +15,8 @@
"""Attention core modules for Flax."""

import functools
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union, overload
import warnings

from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
@@ -248,11 +249,37 @@ class MultiHeadDotProductAttention(Module):
qkv_dot_general_cls: Any = None
out_dot_general_cls: Any = None

@overload
def __call__(
self,
inputs_q: Array,
inputs_k: Optional[Array] = None,
inputs_v: Optional[Array] = None,
*,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
...

@overload
def __call__(
self,
inputs_q: Array,
*,
inputs_kv: Array = None,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
...

@compact
def __call__(
self,
inputs_q: Array,
inputs_kv: Array,
inputs_k: Optional[Array] = None,
inputs_v: Optional[Array] = None,
*,
inputs_kv: Optional[Array] = None,
mask: Optional[Array] = None,
deterministic: Optional[bool] = None,
):
@@ -261,9 +288,19 @@ def __call__(
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and projects the results to an output vector.

If both inputs_k and inputs_v are None, they will both copy the value of
inputs_q (self attention).
If only inputs_v is None, it will copy the value of inputs_k.

Args:
inputs_q: input queries of shape `[batch_sizes..., length, features]`.
inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
inputs_k will copy the value of inputs_q.
inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
inputs_v will copy the value of inputs_k.
inputs_kv: key/values of shape `[batch_sizes..., length, features]`. If
None, inputs_kv will copy the value of inputs_q. This arg will be
deprecated soon. Use inputs_k and inputs_v instead.
mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
key/value_length]`. Attention weights are masked out if their
corresponding mask value is `False`.
@@ -273,6 +310,42 @@
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
if inputs_kv is not None:
  if inputs_k is not None or inputs_v is not None:
    raise ValueError('If either `inputs_k` or `inputs_v` is not None, '
      '`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` '
      'and `inputs_v` must be None. We recommend using `inputs_k` and '
      '`inputs_v` args, since `inputs_kv` will be deprecated soon. See '
      'https://github.com/google/flax/discussions/3389 for more '
      'information.')
  inputs_k = inputs_v = inputs_kv
  warnings.warn('The inputs_kv arg will be deprecated soon. '
    'Use inputs_k and inputs_v instead. See '
    'https://github.com/google/flax/discussions/3389 '
    'for more information.',
    DeprecationWarning)
else:
  if inputs_k is None:
    if inputs_v is not None:
      raise ValueError('`inputs_k` cannot be None if `inputs_v` is not None. '
        'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
        'value to `inputs_k` and leave `inputs_v` as None.')
    inputs_k = inputs_q
  if inputs_v is None:
    inputs_v = inputs_k
  elif inputs_v.shape[-1] == inputs_v.shape[-2]:
    warnings.warn(f"You are passing an array of shape {inputs_v.shape} "
      "to the `inputs_v` arg, when you may have intended "
      "to pass it to the `mask` arg. As of Flax version "
      "0.7.4, the function signature of "
      "MultiHeadDotProductAttention's `__call__` method "
      "has changed to `__call__(inputs_q, inputs_k=None, "
      "inputs_v=None, *, inputs_kv=None, mask=None, "
      "deterministic=None)`. Use the kwarg `mask` instead. "
      "See https://github.com/google/flax/discussions/3389 "
      "and read the docstring for more information.",
      DeprecationWarning)

features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
assert qkv_features % self.num_heads == 0, (
@@ -298,8 +371,8 @@ def __call__(
# dimensions are then [batch..., length, n_heads, n_features_per_head]
query, key, value = (
dense(name='query')(inputs_q),
dense(name='key')(inputs_kv),
dense(name='value')(inputs_kv),
dense(name='key')(inputs_k),
dense(name='value')(inputs_v),
)

if self.normalize_qk:
@@ -429,8 +502,13 @@ def __call__( # type: ignore
Returns:
output of shape `[batch_sizes..., length, features]`.
"""
warnings.warn('SelfAttention will be deprecated soon. Use '
'`MultiHeadDotProductAttention.__call__(inputs_q)` instead. '
'See https://github.com/google/flax/discussions/3389 '
'for more information.',
DeprecationWarning)
return super().__call__(
inputs_q, inputs_q, mask, deterministic=deterministic
inputs_q, mask=mask, deterministic=deterministic
)


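For orientation, here is a minimal usage sketch (not part of the PR) of the calling conventions this diff introduces, assuming a Flax build that includes this change (0.7.4+); the module settings and shapes are illustrative.

import jax
import jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)

rng = jax.random.key(0)
q = jnp.ones((4, 3, 5))   # [batch, query_length, features]
kv = jnp.ones((4, 7, 5))  # [batch, kv_length, features]

# Self-attention: inputs_k and inputs_v default to inputs_q.
variables = attn.init(rng, q)
y_self = attn.apply(variables, q)

# Cross-attention: inputs_v defaults to inputs_k when omitted.
variables = attn.init(rng, q, kv)
y_cross = attn.apply(variables, q, kv)
y_cross_explicit = attn.apply(variables, q, inputs_k=kv, inputs_v=kv)

# Deprecated spelling: still accepted, but emits a DeprecationWarning.
y_legacy = attn.apply(variables, q, inputs_kv=kv)

# The mask is keyword-only under the new signature; a positional third
# argument to __call__ is now read as inputs_v, not mask (hence the
# examples/wmt/models.py change above).
mask = nn.make_attention_mask(jnp.ones((4, 3)), jnp.ones((4, 7)))
y_masked = attn.apply(variables, q, kv, mask=mask)

Note the guard in the diff above: if a square array ends up in inputs_v, the layer warns that the caller probably meant to pass a mask.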
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -147,6 +147,12 @@ filterwarnings = [
"ignore:.*module 'sre_constants' is deprecated.*:DeprecationWarning",
# DeprecationWarning: jax.random.KeyArray is deprecated.
"ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning",
# DeprecationWarning: SelfAttention will be deprecated soon.
"ignore:.*SelfAttention will be deprecated soon.*:DeprecationWarning",
# DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.
"ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning",
# DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed
"ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning"
]

[tool.coverage.report]
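The entries above silence the new warnings in Flax's own test suite. A downstream project migrating incrementally could do the equivalent at runtime; the snippet below is a hypothetical sketch, not part of the PR, using only the standard-library warnings module with the messages added in flax/linen/attention.py above.

import warnings

# Ignore the deprecation messages introduced by this PR while call sites
# are being migrated to inputs_k/inputs_v and MultiHeadDotProductAttention.
warnings.filterwarnings(
  'ignore',
  message='.*The inputs_kv arg will be deprecated soon.*',
  category=DeprecationWarning,
)
warnings.filterwarnings(
  'ignore',
  message='.*SelfAttention will be deprecated soon.*',
  category=DeprecationWarning,
)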
95 changes: 82 additions & 13 deletions tests/linen/linen_attention_test.py
@@ -17,6 +17,7 @@
from absl.testing import absltest
from absl.testing import parameterized

from flax import errors
from flax import linen as nn
from flax import jax_utils
from flax.core import pop
@@ -67,15 +68,14 @@ def test_dtype_infer(self):
def test_multihead_encoder_decoder_attention(self):
rng = random.key(0)
q = jnp.ones((4, 2, 3, 5))
kv = jnp.ones((4, 2, 3, 5))
sa_module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
bias_init=initializers.zeros,
deterministic=False,
)
y, _ = sa_module.init_with_output(rng, q, kv)
y, _ = sa_module.init_with_output(rng, q)
self.assertEqual(y.shape, q.shape)

def test_multihead_self_attention_w_dropout(self):
@@ -91,7 +91,7 @@ def test_multihead_self_attention_w_dropout(self):
)
rng1, rng2 = random.split(rng)
rngs = {'params': rng1, 'dropout': rng2}
y, _ = sa_module.init_with_output(rngs, x, x)
y, _ = sa_module.init_with_output(rngs, x)
self.assertEqual(y.shape, x.shape)

def test_multihead_self_attention_w_dropout_disabled(self):
@@ -108,11 +108,11 @@ def test_multihead_self_attention_w_dropout_disabled(self):
rng1, rng2, rng3, rng4 = random.split(rng, 4)
rngs1 = {'params': rng1, 'dropout': rng2}
rngs2 = {'params': rng3, 'dropout': rng4}
y1, vs = sa_module0.init_with_output(rngs1, x, x)
y2, _ = sa_module0.init_with_output(rngs2, x, x)
y1, vs = sa_module0.init_with_output(rngs1, x)
y2, _ = sa_module0.init_with_output(rngs2, x)
np.testing.assert_allclose(y1, y2)
y3 = sa_module0.apply(vs, x, x, rngs=rngs1)
y4 = sa_module0.apply(vs, x, x, rngs=rngs2)
y3 = sa_module0.apply(vs, x, rngs=rngs1)
y4 = sa_module0.apply(vs, x, rngs=rngs2)
np.testing.assert_allclose(y3, y4)
sa_module1 = nn.MultiHeadDotProductAttention(
num_heads=8,
@@ -121,8 +121,8 @@ def test_multihead_self_attention_w_dropout_disabled(self):
bias_init=initializers.zeros,
dropout_rate=0.0,
)
y5 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs1)
y6 = sa_module1.apply(vs, x, x, deterministic=True, rngs=rngs2)
y5 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs1)
y6 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs2)
np.testing.assert_allclose(y5, y6)
sa_module2 = nn.MultiHeadDotProductAttention(
num_heads=8,
@@ -131,8 +131,8 @@ def test_multihead_self_attention_w_dropout_disabled(self):
bias_init=initializers.zeros,
dropout_rate=0.5,
)
y7 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs1)
y8 = sa_module2.apply(vs, x, x, deterministic=True, rngs=rngs2)
y7 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs1)
y8 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs2)
np.testing.assert_allclose(y7, y8)

def test_causal_mask_1d(self):
Expand Down Expand Up @@ -204,11 +204,11 @@ def test_autoregresive_receptive_field_1d(self):
deterministic=False,
)

initial_vars = module.init(rng1, inputs, inputs)
initial_vars = module.init(rng1, inputs)
causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1]))

def model_loss(inputs, pos):
out = module.apply(initial_vars, inputs, inputs, causal_mask)
out = module.apply(initial_vars, inputs, mask=causal_mask)
assert out.shape == input_shape
assert len(out.shape) == 3
return out[0, pos, :].sum()
@@ -234,6 +234,75 @@ def get_receptive_field_1d(pos):
'autoregressive self-attention.'
)

def test_multihead_self_attention_equality(self):
rng = random.key(0)
q = jnp.ones((4, 2, 3, 5))
module_kwargs = {'num_heads': 8,
'qkv_features': 16,
'kernel_init': initializers.ones,
'bias_init': initializers.zeros,
'deterministic': False}
sa_module0 = nn.MultiHeadDotProductAttention(**module_kwargs)
sa_module1 = nn.SelfAttention(**module_kwargs)
y0, v0 = sa_module0.init_with_output(rng, q)
with self.assertWarnsRegex(DeprecationWarning, 'SelfAttention will be deprecated soon.'):
y1, v1 = sa_module1.init_with_output(rng, q)
self.assertTrue((y0 == y1).all())
self.assertTrue(jax.tree_util.tree_all(jax.tree_map(lambda x, y: (x == y).all(), v0, v1)))

def test_multihead_kv_args(self):
key1, key2 = random.split(random.key(0), 2)
query = random.uniform(key1, (3, 5))
key_value = random.uniform(key1, (9, 5))
module = nn.MultiHeadDotProductAttention(
num_heads=8,
qkv_features=16,
kernel_init=initializers.ones,
bias_init=initializers.zeros,
deterministic=False,
)
y0, v0 = module.init_with_output(key2, query, inputs_k=key_value, inputs_v=key_value)
y1, v1 = module.init_with_output(key2, query, inputs_k=key_value)
with self.assertWarnsRegex(DeprecationWarning, 'The inputs_kv arg will be deprecated soon.'):
y2, v2 = module.init_with_output(key2, query, inputs_kv=key_value)
self.assertTrue((y0 == y1).all() and (y1 == y2).all())
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(lambda x, y, z: (x == y).all() and (y == z).all(),
v0, v1, v2)))

with self.assertRaisesRegex(ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.'):
y3, v3 = module.init_with_output(key2, query, inputs_v=key_value)
with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'):
y3, v3 = module.init_with_output(key2, query, inputs_kv=key_value, inputs_v=key_value)
with self.assertRaisesRegex(ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.'):
y3, v3 = module.init_with_output(key2, query, key_value, key_value, inputs_kv=key_value)

def test_multihead_mask_warning(self):
rng = random.key(0)
rng1, rng2 = random.split(rng, num=2)

length = 10
dim = 1
num_heads = 1
input_shape = (1, length, dim)
query = key = random.normal(rng2, input_shape)

module = nn.MultiHeadDotProductAttention(
num_heads=num_heads,
kernel_init=jax.nn.initializers.ones,
deterministic=False,
)

initial_vars = module.init(rng1, query, key)
causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1]))

module.apply(initial_vars, query, key, mask=causal_mask)
with self.assertWarnsRegex(DeprecationWarning,
"the function signature of MultiHeadDotProductAttention's `__call__` method has changed"):
with self.assertRaises(errors.ScopeParamShapeError):
module.apply(initial_vars, query, key, causal_mask)


if __name__ == '__main__':
absltest.main()
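As a migration note, test_multihead_self_attention_equality above shows that a deprecated SelfAttention module can be replaced by a plain MultiHeadDotProductAttention call on inputs_q alone without changing parameters or outputs. A hypothetical sketch, not part of the PR, mirroring that test:

import jax.numpy as jnp
from jax import random
import flax.linen as nn

x = jnp.ones((4, 6, 16))
rng = random.key(0)
kwargs = dict(num_heads=8, qkv_features=16,
              kernel_init=nn.initializers.ones,
              bias_init=nn.initializers.zeros)

# Before: emits "SelfAttention will be deprecated soon."
y_old, vars_old = nn.SelfAttention(**kwargs).init_with_output(rng, x)

# After: self-attention is the default when only inputs_q is passed.
y_new, vars_new = nn.MultiHeadDotProductAttention(**kwargs).init_with_output(rng, x)

assert (y_old == y_new).all()  # identical parameters and output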