diff --git a/CHANGELOG.md b/CHANGELOG.md
index 231f9660288..4c5f806faed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,7 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `self.ddp_accelerator` during distributed training. This is useful when, for example, instantiating submodules in
   your model's `__init__()` method by wrapping them with `self.ddp_accelerator.wrap_module()`. See the
   `allennlp.modules.transformer.t5` for an example.
 - Added Tango components, to be explored in detail in a later post.
+- Added `ScaledDotProductMatrixAttention`, and converted the transformer toolkit to use it.
+- Added tests to ensure that all `Attention` and `MatrixAttention` implementations are interchangeable.
 
 ### Fixed
 
@@ -34,7 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.
 - `tqdm` lock is now set inside `MultiProcessDataLoading` when new workers are spawned to avoid contention when writing output.
 - `ConfigurationError` is now pickleable.
-- Multitask models now support `TextFieldTensor` in heads, not just in the backbone
+- Multitask models now support `TextFieldTensor` in heads, not just in the backbone.
+- Fixed the signature of `ScaledDotProductAttention` to match the other `Attention` classes.
 
 ### Changed
 
diff --git a/allennlp/modules/attention/scaled_dot_product_attention.py b/allennlp/modules/attention/scaled_dot_product_attention.py
index 36ecf592887..7c28acfe53d 100644
--- a/allennlp/modules/attention/scaled_dot_product_attention.py
+++ b/allennlp/modules/attention/scaled_dot_product_attention.py
@@ -1,11 +1,15 @@
 import math
+from typing import Optional
+
 import torch
 from overrides import overrides
+
+from allennlp.modules.attention.dot_product_attention import DotProductAttention
 from allennlp.modules.attention.attention import Attention
 
 
 @Attention.register("scaled_dot_product")
-class ScaledDotProductAttention(Attention):
+class ScaledDotProductAttention(DotProductAttention):
     """
     Computes attention between two tensors using scaled dot product.
     # Reference: [Attention Is All You Need (Vaswani et al, 2017)]
@@ -22,12 +26,13 @@ class ScaledDotProductAttention(Attention):
         distribution for your attention. If false, this is just computing a similarity score.
     """
 
-    def __init__(self, scaling_factor: int, normalize: bool = True) -> None:
+    def __init__(self, scaling_factor: Optional[int] = None, normalize: bool = True) -> None:
         super().__init__(normalize)
         self.scaling_factor = scaling_factor
 
     @overrides
     def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
-        scores = torch.matmul(vector, matrix)
-        scores = scores / math.sqrt(self.scaling_factor)
+        scores = super()._forward_internal(vector, matrix)
+        scaling_factor = self.scaling_factor or matrix.size(-1)
+        scores = scores / math.sqrt(scaling_factor)
         return scores
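A minimal usage sketch (illustrative only, not part of the patch) of what the `Optional` scaling factor changes: when `scaling_factor` is left as `None`, the module falls back to the size of the last dimension of `matrix`, so the explicit and implicit forms below should agree. The tensor shapes and variable names are assumptions for the example.

```python
import math

import torch

from allennlp.modules.attention import DotProductAttention, ScaledDotProductAttention

vector = torch.randn(2, 3)     # (batch_size, embedding_dim)
matrix = torch.randn(2, 5, 3)  # (batch_size, num_rows, embedding_dim)

# Old style: the scaling factor had to be passed explicitly.
explicit = ScaledDotProductAttention(scaling_factor=3, normalize=False)(vector, matrix)
# New default: scaling_factor=None falls back to matrix.size(-1) == 3.
implicit = ScaledDotProductAttention(normalize=False)(vector, matrix)
# Both are just the unscaled dot product divided by sqrt(3).
unscaled = DotProductAttention(normalize=False)(vector, matrix)

assert torch.allclose(explicit, implicit)
assert torch.allclose(implicit, unscaled / math.sqrt(3))
```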
""" - def __init__(self, scaling_factor: int, normalize: bool = True) -> None: + def __init__(self, scaling_factor: Optional[int] = None, normalize: bool = True) -> None: super().__init__(normalize) self.scaling_factor = scaling_factor @overrides def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor: - scores = torch.matmul(vector, matrix) - scores = scores / math.sqrt(self.scaling_factor) + scores = super()._forward_internal(vector, matrix) + scaling_factor = self.scaling_factor or matrix.size(-1) + scores = scores / math.sqrt(scaling_factor) return scores diff --git a/allennlp/modules/matrix_attention/__init__.py b/allennlp/modules/matrix_attention/__init__.py index 4807383db9d..7bfdbfb7d53 100644 --- a/allennlp/modules/matrix_attention/__init__.py +++ b/allennlp/modules/matrix_attention/__init__.py @@ -2,4 +2,7 @@ from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention from allennlp.modules.matrix_attention.cosine_matrix_attention import CosineMatrixAttention from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention +from allennlp.modules.matrix_attention.scaled_dot_product_matrix_attention import ( + ScaledDotProductMatrixAttention, +) from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention diff --git a/allennlp/modules/matrix_attention/dot_product_matrix_attention.py b/allennlp/modules/matrix_attention/dot_product_matrix_attention.py index b13ed5f9851..a16e18e2682 100644 --- a/allennlp/modules/matrix_attention/dot_product_matrix_attention.py +++ b/allennlp/modules/matrix_attention/dot_product_matrix_attention.py @@ -15,4 +15,4 @@ class DotProductMatrixAttention(MatrixAttention): @overrides def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor: - return matrix_1.bmm(matrix_2.transpose(2, 1)) + return matrix_1.matmul(matrix_2.transpose(-1, -2)) diff --git a/allennlp/modules/matrix_attention/matrix_attention.py b/allennlp/modules/matrix_attention/matrix_attention.py index 7bec2cec4ed..783923567aa 100644 --- a/allennlp/modules/matrix_attention/matrix_attention.py +++ b/allennlp/modules/matrix_attention/matrix_attention.py @@ -8,7 +8,7 @@ class MatrixAttention(torch.nn.Module, Registrable): `MatrixAttention` takes two matrices as input and returns a matrix of attentions. We compute the similarity between each row in each matrix and return unnormalized similarity - scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the + scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the caller to deal with masking properly when this output is used. Input: diff --git a/allennlp/modules/matrix_attention/scaled_dot_product_matrix_attention.py b/allennlp/modules/matrix_attention/scaled_dot_product_matrix_attention.py new file mode 100644 index 00000000000..5854c4d0c91 --- /dev/null +++ b/allennlp/modules/matrix_attention/scaled_dot_product_matrix_attention.py @@ -0,0 +1,21 @@ +import math + +import torch +from overrides import overrides + +from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention +from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention + + +@MatrixAttention.register("scaled_dot_product") +class ScaledDotProductMatrixAttention(DotProductMatrixAttention): + """ + Computes attention between every entry in matrix_1 with every entry in matrix_2 using a dot + product. 
diff --git a/allennlp/modules/transformer/attention_module.py b/allennlp/modules/transformer/attention_module.py
index 4d98caba4b2..e0be3e99bbc 100644
--- a/allennlp/modules/transformer/attention_module.py
+++ b/allennlp/modules/transformer/attention_module.py
@@ -6,7 +6,7 @@
 
 from allennlp.common import FromParams
 from allennlp.common.checks import ConfigurationError
-from allennlp.modules.attention import Attention
+from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
 from allennlp.modules.transformer.transformer_module import TransformerModule
 from allennlp.modules.transformer.util import apply_mask, FloatT, IntT, BoolT
 
@@ -51,7 +51,7 @@ class AttentionModule(TransformerModule, FromParams):
     scoring_func: `str` (default = `scaled_dot_product`)
         The name of the attention-calculating function to be used.
         Eg. `additive`, `linear`, etc. For a complete list, please check
-        :mod:`allennlp.modules.attention.attention`.
+        :mod:`allennlp.modules.matrix_attention.matrix_attention`.
     output_linear: `bool` (default = `False`)
         Whether to add an additional output linear layer at the end.
     dropout: `float` (default = `0.0`)
@@ -113,12 +113,7 @@ def __init__(
             self.output = torch.nn.Linear(self.all_head_size, hidden_size, bias=bias)
 
         self.scoring_func = scoring_func
-        if self.scoring_func in ["additive", "linear", "bilinear"]:
-            self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size)
-        elif self.scoring_func == "scaled_dot_product":
-            self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False)
-        else:
-            self.attn = Attention.by_name(self.scoring_func)()
+        self.attn = MatrixAttention.by_name(self.scoring_func)()
 
         self.relative_attention_num_buckets = relative_attention_num_buckets
 
@@ -229,7 +224,7 @@ def _get_attention_probs(
         past_key_states: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attn(query_layer, key_layer)
 
         position_bias = self._position_bias(
             position_bias, seq_lengths, past_key_states, attention_scores
@@ -478,7 +473,7 @@ def __init__(
             attention_head_size=key_value_proj_dim,
             num_attention_heads=num_heads,
             output_linear=True,
-            scoring_func="scaled_dot_product",
+            scoring_func="dot_product",
             dropout=dropout,
             bias=False,
             normalize_weights=normalize,
@@ -487,8 +482,6 @@ def __init__(
             relative_attention_num_buckets=relative_attention_num_buckets,
         )
 
-        self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False)
-
     def forward(  # type: ignore
         self,
         hidden_states: torch.Tensor,
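The toolkit change above replaces the per-type constructor special-casing with a plain registry lookup: every scoring function used here is now a zero-argument `MatrixAttention`, and the caller no longer transposes the key layer. `T5Attention` also switches from `"scaled_dot_product"` to `"dot_product"`, which preserves behavior because the old code passed `scaling_factor=1`, making the scaling a no-op. A rough sketch of the resulting call pattern (shapes are illustrative, not taken from the patch):

```python
import torch

from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention

# Scoring functions are resolved through the MatrixAttention registry, and both
# "dot_product" and "scaled_dot_product" take no constructor arguments.
attn = MatrixAttention.by_name("scaled_dot_product")()

# Per-head projections: (batch_size, num_heads, seq_len, head_dim).
query_layer = torch.randn(2, 12, 7, 64)
key_layer = torch.randn(2, 12, 7, 64)

# No key_layer.transpose(-1, -2) at the call site any more; the module handles it.
attention_scores = attn(query_layer, key_layer)
assert attention_scores.shape == (2, 12, 7, 7)
```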
diff --git a/allennlp/modules/transformer/bimodal_attention.py b/allennlp/modules/transformer/bimodal_attention.py
index cdec6223506..cb5259205ab 100644
--- a/allennlp/modules/transformer/bimodal_attention.py
+++ b/allennlp/modules/transformer/bimodal_attention.py
@@ -1,7 +1,7 @@
 import torch
 
 from allennlp.common import FromParams
-from allennlp.modules.attention import Attention
+from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
 from allennlp.modules.transformer.transformer_module import TransformerModule
 from allennlp.modules.transformer.util import apply_mask
 
@@ -44,7 +44,8 @@ class BiModalAttention(TransformerModule, FromParams):
         The name of the attention-calculating function to be used for the first modality.
     scoring_func2 : `str` (default = `scaled_dot_product`)
         The name of the attention-calculating function to be used for the second modality.
-        Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`.
+        Eg. `dot_product`, `linear`, etc. For a complete list, please check
+        :mod:`allennlp.modules.matrix_attention`.
     """
 
     def __init__(
@@ -79,13 +80,7 @@ def __init__(
         self.value1 = torch.nn.Linear(hidden_size1, self.all_head_size)
 
         self.scoring_func1 = scoring_func1
-        if self.scoring_func1 in ["additive", "linear", "bilinear"]:
-            self.attn1 = Attention.by_name(self.scoring_func1)(hidden_size1, hidden_size1)
-        elif self.scoring_func1 == "scaled_dot_product":
-            self.attn1 = Attention.by_name(self.scoring_func1)(self.attention_head_size, False)
-        else:
-            self.attn1 = Attention.by_name(self.scoring_func1)()
-
+        self.attn1 = MatrixAttention.by_name(self.scoring_func1)()
         self.dropout1 = torch.nn.Dropout(dropout1)
 
         # Second modality:
@@ -95,13 +90,7 @@ def __init__(
         self.value2 = torch.nn.Linear(hidden_size2, self.all_head_size)
 
         self.scoring_func2 = scoring_func2
-        if self.scoring_func2 in ["additive", "linear", "bilinear"]:
-            self.attn2 = Attention.by_name(self.scoring_func2)(hidden_size2, hidden_size2)
-        elif self.scoring_func2 == "scaled_dot_product":
-            self.attn2 = Attention.by_name(self.scoring_func2)(self.attention_head_size, False)
-        else:
-            self.attn2 = Attention.by_name(self.scoring_func2)()
-
+        self.attn2 = MatrixAttention.by_name(self.scoring_func2)()
         self.dropout2 = torch.nn.Dropout(dropout2)
 
     def _transpose_for_scores(self, x):
@@ -164,7 +153,7 @@ def forward(
         value_layer2 = self._transpose_for_scores(mixed_value_layer2)
 
         # Conditioning the second modality on the first one.
-        attention_scores1 = self.attn1(query_layer2, key_layer1.transpose(-1, -2))
+        attention_scores1 = self.attn1(query_layer2, key_layer1)
         if attention_mask1 is not None:
             attention_scores1 = apply_mask(attention_scores1, attention_mask1)
         if co_attention_mask is not None:
@@ -182,7 +171,7 @@ def forward(
         context_layer1 = context_layer1.view(*new_context_layer_shape1)
 
         # Conditioning the first modality on the second one.
-        attention_scores2 = self.attn2(query_layer1, key_layer2.transpose(-1, -2))
+        attention_scores2 = self.attn2(query_layer1, key_layer2)
         # we can comment this line for single flow.
         if attention_mask2 is not None:
             attention_scores2 = apply_mask(attention_scores2, attention_mask2)
diff --git a/tests/modules/attention/attention_test.py b/tests/modules/attention/attention_test.py
new file mode 100644
index 00000000000..dcd91ce0383
--- /dev/null
+++ b/tests/modules/attention/attention_test.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+
+from allennlp.modules import Attention
+from allennlp.modules.attention import BilinearAttention, AdditiveAttention, LinearAttention
+
+
+@pytest.mark.parametrize("attention_type", Attention.list_available())
+def test_all_attention_works_the_same(attention_type: str):
+    module_cls = Attention.by_name(attention_type)
+
+    vector = torch.FloatTensor([[-7, -8, -9]])
+    matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])
+
+    if module_cls in {BilinearAttention, AdditiveAttention, LinearAttention}:
+        module = module_cls(vector.size(-1), matrix.size(-1))
+    else:
+        module = module_cls()
+
+    output = module(vector, matrix)
+    assert tuple(output.size()) == (1, 2)
diff --git a/tests/modules/attention/dot_product_attention_test.py b/tests/modules/attention/dot_product_attention_test.py
index b91a68e286c..b14e9464d9a 100644
--- a/tests/modules/attention/dot_product_attention_test.py
+++ b/tests/modules/attention/dot_product_attention_test.py
@@ -14,8 +14,8 @@ def test_can_init_dot(self):
         isinstance(legacy_attention, DotProductAttention)
 
     def test_dot_product_similarity(self):
-        linear = DotProductAttention(normalize=False)
-        output = linear(
+        attn = DotProductAttention(normalize=False)
+        output = attn(
             torch.FloatTensor([[0, 0, 0], [1, 1, 1]]),
             torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]),
         )
diff --git a/tests/modules/attention/scaled_dot_product_attention_test.py b/tests/modules/attention/scaled_dot_product_attention_test.py
index 247cafc200d..866f0a22e30 100644
--- a/tests/modules/attention/scaled_dot_product_attention_test.py
+++ b/tests/modules/attention/scaled_dot_product_attention_test.py
@@ -18,13 +18,7 @@ def test_can_init_scaled_dot(self):
     def test_scaled_dot_product_similarity(self):
         attn = ScaledDotProductAttention(9, normalize=False)
         vector = torch.FloatTensor([[0, 0, 0], [1, 1, 1]])
-        matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]).transpose(
-            -1, -2
-        )
+        matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
         output = attn(vector, matrix)
 
-        assert_almost_equal(
-            output.numpy(),
-            numpy.array([[[0.0, 0.0], [2.0, 5.0]], [[0.0, 0.0], [8.0, 11.0]]]),
-            decimal=2,
-        )
+        assert_almost_equal(output.numpy(), numpy.array([[0.0, 0.0], [8.0, 11.0]]), decimal=2)
diff --git a/tests/modules/matrix_attention/matrix_attention_test.py b/tests/modules/matrix_attention/matrix_attention_test.py
new file mode 100644
index 00000000000..c87a036babe
--- /dev/null
+++ b/tests/modules/matrix_attention/matrix_attention_test.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+
+from allennlp.modules import MatrixAttention
+from allennlp.modules.matrix_attention import BilinearMatrixAttention, LinearMatrixAttention
+
+
+@pytest.mark.parametrize("attention_type", MatrixAttention.list_available())
+def test_all_attention_works_the_same(attention_type: str):
+    module_cls = MatrixAttention.by_name(attention_type)
+
+    matrix1 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
+    matrix2 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])
+
+    if module_cls in {BilinearMatrixAttention, LinearMatrixAttention}:
+        module = module_cls(matrix1.size(-1), matrix2.size(-1))
+    else:
+        module = module_cls()
+
+    output = module(matrix1, matrix2)
+    assert tuple(output.size()) == (1, 3, 2)
diff --git a/tests/modules/matrix_attention/scaled_dot_product_matrix_attention_test.py b/tests/modules/matrix_attention/scaled_dot_product_matrix_attention_test.py
new file mode 100644
index 00000000000..7767a1ccedb
--- /dev/null
+++ b/tests/modules/matrix_attention/scaled_dot_product_matrix_attention_test.py
@@ -0,0 +1,37 @@
+import math
+
+import torch
+from numpy.testing import assert_almost_equal
+import numpy
+
+from allennlp.common import Params
+from allennlp.common.testing.test_case import AllenNlpTestCase
+from allennlp.modules.matrix_attention import ScaledDotProductMatrixAttention
+from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
+
+
+class TestScaledDotProductMatrixAttention(AllenNlpTestCase):
+    def test_can_init_dot(self):
+        legacy_attention = MatrixAttention.from_params(Params({"type": "scaled_dot_product"}))
+        assert isinstance(legacy_attention, ScaledDotProductMatrixAttention)
+
+    def test_dot_product_similarity(self):
+        # Example use case: a batch of size 2 with a time dimension (e.g. sentences
+        # of length 2), where each word is represented by a vector of length 3.
+        # We compare this against another input of the same shape.
+        output = ScaledDotProductMatrixAttention()(
+            torch.FloatTensor([[[0, 0, 0], [4, 5, 6]], [[-7, -8, -9], [10, 11, 12]]]),
+            torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]),
+        )
+
+        # For the first batch, there is no correlation between the first words of
+        # the two inputs, but perfect correlation for the second words. For the
+        # second batch, the first words are negatively correlated and the second
+        # words are positively correlated.
+        # Each expected score is the raw dot product divided by sqrt(3), the size
+        # of the embedding dimension.
+        assert_almost_equal(
+            output.numpy(),
+            numpy.array([[[0, 0], [32, 77]], [[-194, -266], [266, 365]]]) / math.sqrt(3),
+            decimal=2,
+        )
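As a sanity check on the expected values in the test above (a sketch, not part of the patch): each entry is the raw dot product of a row of the first input with a row of the second input, divided by `sqrt(3)`, the embedding size.

```python
import math

import torch

matrix_1 = torch.FloatTensor([[[0, 0, 0], [4, 5, 6]], [[-7, -8, -9], [10, 11, 12]]])
matrix_2 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

# Raw dot products, e.g. [4, 5, 6] . [1, 2, 3] = 32 and [4, 5, 6] . [4, 5, 6] = 77.
raw = matrix_1.matmul(matrix_2.transpose(-1, -2))
assert raw.tolist() == [[[0.0, 0.0], [32.0, 77.0]], [[-194.0, -266.0], [266.0, 365.0]]]

# The values asserted in the test are these raw scores divided by sqrt(3).
scaled = raw / math.sqrt(3)
```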