This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Make sure that all attention works the same #5360

Merged · 8 commits · Aug 17, 2021
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -24,7 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`self.ddp_accelerator` during distributed training. This is useful when, for example, instantiating submodules in your
model's `__init__()` method by wrapping them with `self.ddp_accelerator.wrap_module()`. See the `allennlp.modules.transformer.t5`
for an example.
- Added Tango components, to be explored in detail in a later post.
- Added Tango components, to be explored in detail in a later post
- Added `ScaledDotProductMatrixAttention`, and converted the transformer toolkit to use it
- Added tests to ensure that all `Attention` and `MatrixAttention` implementations are interchangeable

### Fixed

@@ -34,7 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.
- `tqdm` lock is now set inside `MultiProcessDataLoading` when new workers are spawned to avoid contention when writing output.
- `ConfigurationError` is now pickleable.
- Multitask models now support `TextFieldTensor` in heads, not just in the backbone
- Multitask models now support `TextFieldTensor` in heads, not just in the backbone.
- Fixed the signature of `ScaledDotProductAttention` to match the other `Attention` classes

### Changed

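To illustrate the interchangeability described in the changelog entries above, here is a minimal sketch (not part of the diff; the shapes are made up) that builds a registered `MatrixAttention` by name and applies it:

```python
import torch

from allennlp.modules import MatrixAttention

# Any registered MatrixAttention is constructed by name and called the same way:
# two (batch, rows, dim) matrices in, a (batch, rows_1, rows_2) score matrix out.
attn = MatrixAttention.by_name("scaled_dot_product")()
matrix_1 = torch.randn(1, 3, 4)
matrix_2 = torch.randn(1, 5, 4)
scores = attn(matrix_1, matrix_2)
print(scores.shape)  # torch.Size([1, 3, 5])
```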
13 changes: 9 additions & 4 deletions allennlp/modules/attention/scaled_dot_product_attention.py
@@ -1,11 +1,15 @@
import math
from typing import Optional

import torch
from overrides import overrides

from allennlp.modules.attention.dot_product_attention import DotProductAttention
from allennlp.modules.attention.attention import Attention


@Attention.register("scaled_dot_product")
class ScaledDotProductAttention(Attention):
class ScaledDotProductAttention(DotProductAttention):
"""
Computes attention between two tensors using scaled dot product.
# Reference: [Attention Is All You Need (Vaswani et al, 2017)]
@@ -22,12 +26,13 @@ class ScaledDotProductAttention(Attention):
distribution for your attention. If false, this is just computing a similarity score.
"""

def __init__(self, scaling_factor: int, normalize: bool = True) -> None:
def __init__(self, scaling_factor: Optional[int] = None, normalize: bool = True) -> None:
super().__init__(normalize)
self.scaling_factor = scaling_factor

@overrides
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
scores = torch.matmul(vector, matrix)
scores = scores / math.sqrt(self.scaling_factor)
scores = super()._forward_internal(vector, matrix)
scaling_factor = self.scaling_factor or matrix.size(-1)
scores = scores / math.sqrt(scaling_factor)
return scores
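A short usage sketch of the updated `ScaledDotProductAttention` signature (illustrative values, not part of the diff): with `scaling_factor` omitted, the dot products are divided by the square root of `matrix.size(-1)`.

```python
import math

import torch

from allennlp.modules.attention.scaled_dot_product_attention import ScaledDotProductAttention

attn = ScaledDotProductAttention(normalize=False)      # scaling_factor defaults to None
vector = torch.FloatTensor([[1, 1, 1]])                # (batch, dim)
matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])   # (batch, rows, dim)
scores = attn(vector, matrix)                          # (batch, rows)
# Raw dot products are [6, 15]; the default scaling divides them by sqrt(3).
assert torch.allclose(scores, torch.tensor([[6.0, 15.0]]) / math.sqrt(3))
```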
3 changes: 3 additions & 0 deletions allennlp/modules/matrix_attention/__init__.py
@@ -2,4 +2,7 @@
from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention
from allennlp.modules.matrix_attention.cosine_matrix_attention import CosineMatrixAttention
from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention
from allennlp.modules.matrix_attention.scaled_dot_product_matrix_attention import (
ScaledDotProductMatrixAttention,
)
from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention
allennlp/modules/matrix_attention/dot_product_matrix_attention.py
@@ -15,4 +15,4 @@ class DotProductMatrixAttention(MatrixAttention):

@overrides
def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
return matrix_1.bmm(matrix_2.transpose(2, 1))
return matrix_1.matmul(matrix_2.transpose(-1, -2))
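The switch from `bmm` to `matmul` is what lets the same module score per-head tensors; a quick sketch with illustrative shapes:

```python
import torch

# bmm only accepts 3-D inputs, while matmul broadcasts over leading dimensions,
# so (batch, heads, seq, dim) tensors can be scored directly.
queries = torch.randn(2, 4, 5, 8)   # (batch, heads, seq_q, dim)
keys = torch.randn(2, 4, 7, 8)      # (batch, heads, seq_k, dim)
scores = queries.matmul(keys.transpose(-1, -2))
print(scores.shape)  # torch.Size([2, 4, 5, 7])
```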
2 changes: 1 addition & 1 deletion allennlp/modules/matrix_attention/matrix_attention.py
@@ -8,7 +8,7 @@ class MatrixAttention(torch.nn.Module, Registrable):
`MatrixAttention` takes two matrices as input and returns a matrix of attentions.

We compute the similarity between each row in each matrix and return unnormalized similarity
scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the
scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the
caller to deal with masking properly when this output is used.

Input:
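A sketch of the caller-side masking the docstring describes (the mask handling below is illustrative, not an AllenNLP utility):

```python
import torch

from allennlp.modules.matrix_attention import DotProductMatrixAttention

# MatrixAttention returns unnormalized scores; the caller masks before any softmax.
scores = DotProductMatrixAttention()(torch.randn(1, 3, 4), torch.randn(1, 5, 4))  # (1, 3, 5)
mask = torch.ones(1, 3, 5, dtype=torch.bool)
mask[..., -1] = False                                  # pretend the last key position is padding
probs = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)
```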
21 changes: 21 additions & 0 deletions allennlp/modules/matrix_attention/scaled_dot_product_matrix_attention.py
@@ -0,0 +1,21 @@
import math

import torch
from overrides import overrides

from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention


@MatrixAttention.register("scaled_dot_product")
class ScaledDotProductMatrixAttention(DotProductMatrixAttention):
"""
Computes attention between every entry in matrix_1 and every entry in matrix_2 using a dot
product, then scales the result by the square root of the embedding size.

Registered as a `MatrixAttention` with name "scaled_dot_product".
"""

@overrides
def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
return super().forward(matrix_1, matrix_2) / math.sqrt(matrix_1.size(-1))
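A minimal check of the new class, using the same shape conventions as the tests added below:

```python
import math

import torch

from allennlp.modules.matrix_attention import ScaledDotProductMatrixAttention

# The score for a pair of rows is their dot product divided by sqrt(dim).
m1 = torch.FloatTensor([[[1.0, 2.0, 3.0]]])   # (1, 1, 3)
m2 = torch.FloatTensor([[[1.0, 1.0, 1.0]]])   # (1, 1, 3)
scores = ScaledDotProductMatrixAttention()(m1, m2)
assert torch.allclose(scores, torch.tensor([[[6.0 / math.sqrt(3)]]]))
```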
17 changes: 5 additions & 12 deletions allennlp/modules/transformer/attention_module.py
@@ -6,7 +6,7 @@

from allennlp.common import FromParams
from allennlp.common.checks import ConfigurationError
from allennlp.modules.attention import Attention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.util import apply_mask, FloatT, IntT, BoolT

@@ -51,7 +51,7 @@ class AttentionModule(TransformerModule, FromParams):
scoring_func: `str` (default = `scaled_dot_product`)
The name of the attention-calculating function to be used.
Eg. `additive`, `linear`, etc. For a complete list, please check
:mod:`allennlp.modules.attention.attention`.
:mod:`allennlp.modules.matrix_attention.matrix_attention`.
output_linear: `bool` (default = `False`)
Whether to add an additional output linear layer at the end.
dropout: `float` (default = `0.0`)
@@ -113,12 +113,7 @@ def __init__(
self.output = torch.nn.Linear(self.all_head_size, hidden_size, bias=bias)

self.scoring_func = scoring_func
if self.scoring_func in ["additive", "linear", "bilinear"]:
self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size)
elif self.scoring_func == "scaled_dot_product":
self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False)
else:
self.attn = Attention.by_name(self.scoring_func)()
self.attn = MatrixAttention.by_name(self.scoring_func)()

self.relative_attention_num_buckets = relative_attention_num_buckets

@@ -229,7 +224,7 @@ def _get_attention_probs(
past_key_states: Optional[torch.Tensor] = None,
**kwargs,
):
attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2))
attention_scores = self.attn(query_layer, key_layer)

position_bias = self._position_bias(
position_bias, seq_lengths, past_key_states, attention_scores
@@ -478,7 +473,7 @@ def __init__(
attention_head_size=key_value_proj_dim,
num_attention_heads=num_heads,
output_linear=True,
scoring_func="scaled_dot_product",
scoring_func="dot_product",
Contributor: Why do we change the default? The original transformer uses scaled_dot_product, right?

Member Author: I asked you that on Slack. The original transformer uses scaled_dot_product, as per the paper, but in your implementation (which matches HF), the scaling factor is forced to 1, so it doesn't scale at all. I continue to be confused about this.

Contributor: Sorry, I missed that. From what I can see, the scaling factor is set to 1 for T5Attention, not for regular SelfAttention. I believe the original T5 does the same. By default, we set a scaling factor for regular SelfAttention.

Member Author: Right. This is also only for T5Attention. I believe I left it the same by default.

Member Author: So I guess it's just T5 being extra?

Member Author: I re-read section 2.1 of the T5 paper, and it doesn't mention this at all 🤷🏼‍♂️.

Contributor: Yes, a whole lot of finicky little training details aren't mentioned in the 60+ page paper. I think we were following the HF implementation.

dropout=dropout,
bias=False,
normalize_weights=normalize,
@@ -487,8 +482,6 @@ def __init__(
relative_attention_num_buckets=relative_attention_num_buckets,
)

self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False)
Member Author: @AkshitaB, this is where the scaling factor is forced to 1.

Contributor: This is for T5Attention.


def forward( # type: ignore
self,
hidden_states: torch.Tensor,
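To make the scaling discussion in the review thread above concrete: standard scaled dot-product attention (Vaswani et al.) divides the raw scores by sqrt(d_k), while the T5/HuggingFace-style variant with the scaling factor forced to 1 skips that division. A hedged sketch, not the module's actual code path:

```python
import math

import torch


def raw_attention_scores(query: torch.Tensor, key: torch.Tensor, scale: bool = True) -> torch.Tensor:
    """Illustrative only: scaled vs. unscaled (T5-style) dot-product scores."""
    scores = query.matmul(key.transpose(-1, -2))
    if scale:
        scores = scores / math.sqrt(query.size(-1))  # the sqrt(d_k) term T5 omits
    return scores
```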
25 changes: 7 additions & 18 deletions allennlp/modules/transformer/bimodal_attention.py
@@ -1,7 +1,7 @@
import torch

from allennlp.common import FromParams
from allennlp.modules.attention import Attention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.util import apply_mask

@@ -44,7 +44,8 @@ class BiModalAttention(TransformerModule, FromParams):
The name of the attention-calculating function to be used for the first modality.
scoring_func2 : `str` (default = `scaled_dot_product`)
The name of the attention-calculating function to be used for the second modality.
Eg. `additive`, `linear`, etc. For a complete list, please check :mod:`allennlp.modules.attention`.
Eg. `dot_product`, `linear`, etc. For a complete list, please check
:mod:`allennlp.modules.matrix_attention`.
"""

def __init__(
@@ -79,13 +80,7 @@ def __init__(
self.value1 = torch.nn.Linear(hidden_size1, self.all_head_size)

self.scoring_func1 = scoring_func1
if self.scoring_func1 in ["additive", "linear", "bilinear"]:
self.attn1 = Attention.by_name(self.scoring_func1)(hidden_size1, hidden_size1)
elif self.scoring_func1 == "scaled_dot_product":
self.attn1 = Attention.by_name(self.scoring_func1)(self.attention_head_size, False)
else:
self.attn1 = Attention.by_name(self.scoring_func1)()

self.attn1 = MatrixAttention.by_name(self.scoring_func1)()
self.dropout1 = torch.nn.Dropout(dropout1)

# Second modality:
@@ -95,13 +90,7 @@
self.value2 = torch.nn.Linear(hidden_size2, self.all_head_size)

self.scoring_func2 = scoring_func2
if self.scoring_func2 in ["additive", "linear", "bilinear"]:
self.attn2 = Attention.by_name(self.scoring_func2)(hidden_size2, hidden_size2)
elif self.scoring_func2 == "scaled_dot_product":
self.attn2 = Attention.by_name(self.scoring_func2)(self.attention_head_size, False)
else:
self.attn2 = Attention.by_name(self.scoring_func2)()

self.attn2 = MatrixAttention.by_name(self.scoring_func2)()
self.dropout2 = torch.nn.Dropout(dropout2)

def _transpose_for_scores(self, x):
@@ -164,7 +153,7 @@ def forward(
value_layer2 = self._transpose_for_scores(mixed_value_layer2)

# Conditioning the second modality on the first one.
attention_scores1 = self.attn1(query_layer2, key_layer1.transpose(-1, -2))
attention_scores1 = self.attn1(query_layer2, key_layer1)
if attention_mask1 is not None:
attention_scores1 = apply_mask(attention_scores1, attention_mask1)
if co_attention_mask is not None:
@@ -182,7 +171,7 @@ def forward(
context_layer1 = context_layer1.view(*new_context_layer_shape1)

# Conditioning the first modality on the second one.
attention_scores2 = self.attn2(query_layer1, key_layer2.transpose(-1, -2))
attention_scores2 = self.attn2(query_layer1, key_layer2)
# we can comment this line for single flow.
if attention_mask2 is not None:
attention_scores2 = apply_mask(attention_scores2, attention_mask2)
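The bimodal change mirrors the one in `attention_module.py`: each direction of co-attention is now a plain `MatrixAttention` call between one modality's queries and the other modality's keys, with no manual transpose. A small sketch with made-up shapes:

```python
import torch

from allennlp.modules import MatrixAttention

attn = MatrixAttention.by_name("scaled_dot_product")()
query_layer2 = torch.randn(1, 4, 6, 16)   # modality 2: (batch, heads, seq_2, head_dim)
key_layer1 = torch.randn(1, 4, 9, 16)     # modality 1: (batch, heads, seq_1, head_dim)
attention_scores1 = attn(query_layer2, key_layer1)
print(attention_scores1.shape)  # torch.Size([1, 4, 6, 9])
```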
21 changes: 21 additions & 0 deletions tests/modules/attention/attention_test.py
@@ -0,0 +1,21 @@
import pytest
import torch

from allennlp.modules import Attention
from allennlp.modules.attention import BilinearAttention, AdditiveAttention, LinearAttention


@pytest.mark.parametrize("attention_type", Attention.list_available())
def test_all_attention_works_the_same(attention_type: str):
module_cls = Attention.by_name(attention_type)

vector = torch.FloatTensor([[-7, -8, -9]])
matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])

if module_cls in {BilinearAttention, AdditiveAttention, LinearAttention}:
module = module_cls(vector.size(-1), matrix.size(-1))
else:
module = module_cls()

output = module(vector, matrix)
assert tuple(output.size()) == (1, 2)
4 changes: 2 additions & 2 deletions tests/modules/attention/dot_product_attention_test.py
@@ -14,8 +14,8 @@ def test_can_init_dot(self):
isinstance(legacy_attention, DotProductAttention)

def test_dot_product_similarity(self):
linear = DotProductAttention(normalize=False)
output = linear(
attn = DotProductAttention(normalize=False)
output = attn(
torch.FloatTensor([[0, 0, 0], [1, 1, 1]]),
torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]),
)
10 changes: 2 additions & 8 deletions tests/modules/attention/scaled_dot_product_attention_test.py
@@ -18,13 +18,7 @@ def test_can_init_scaled_dot(self):
def test_scaled_dot_product_similarity(self):
attn = ScaledDotProductAttention(9, normalize=False)
vector = torch.FloatTensor([[0, 0, 0], [1, 1, 1]])
matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]).transpose(
-1, -2
)
matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
output = attn(vector, matrix)

assert_almost_equal(
output.numpy(),
numpy.array([[[0.0, 0.0], [2.0, 5.0]], [[0.0, 0.0], [8.0, 11.0]]]),
decimal=2,
)
assert_almost_equal(output.numpy(), numpy.array([[0.0, 0.0], [8.0, 11.0]]), decimal=2)
21 changes: 21 additions & 0 deletions tests/modules/matrix_attention/matrix_attention_test.py
@@ -0,0 +1,21 @@
import pytest
import torch

from allennlp.modules import MatrixAttention
from allennlp.modules.matrix_attention import BilinearMatrixAttention, LinearMatrixAttention


@pytest.mark.parametrize("attention_type", MatrixAttention.list_available())
def test_all_attention_works_the_same(attention_type: str):
module_cls = MatrixAttention.by_name(attention_type)

matrix1 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
matrix2 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])

if module_cls in {BilinearMatrixAttention, LinearMatrixAttention}:
module = module_cls(matrix1.size(-1), matrix2.size(-1))
else:
module = module_cls()

output = module(matrix1, matrix2)
assert tuple(output.size()) == (1, 3, 2)
37 changes: 37 additions & 0 deletions (new test file for ScaledDotProductMatrixAttention)
@@ -0,0 +1,37 @@
import math

import torch
from numpy.testing import assert_almost_equal
import numpy

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.matrix_attention import ScaledDotProductMatrixAttention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention


class TestScaledDotProductMatrixAttention(AllenNlpTestCase):
def test_can_init_dot(self):
legacy_attention = MatrixAttention.from_params(Params({"type": "scaled_dot_product"}))
assert isinstance(legacy_attention, ScaledDotProductMatrixAttention)

def test_dot_product_similarity(self):
# Example use case: a batch of size 2 with a time dimension (e.g. sentences of length 2),
# where each word is a vector of length 3. We compare this with another input of the same shape.
output = ScaledDotProductMatrixAttention()(
torch.FloatTensor([[[0, 0, 0], [4, 5, 6]], [[-7, -8, -9], [10, 11, 12]]]),
torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]),
)

# For the first batch, there is no correlation between the first words of the two inputs,
# but perfect correlation for the second word. For the second batch, there is negative
# correlation for the first words and positive correlation for the second word.
assert_almost_equal(
output.numpy(),
numpy.array([[[0, 0], [32, 77]], [[-194, -266], [266, 365]]]) / math.sqrt(3),
decimal=2,
)
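A worked check of one entry in the expected array above (a standalone arithmetic snippet, not part of the test): the second row of the first batch compares [4, 5, 6] with [4, 5, 6], whose dot product is 16 + 25 + 36 = 77, giving 77 / sqrt(3) ≈ 44.46 after scaling.

```python
import math

# One cell of the expected matrix: dot([4, 5, 6], [4, 5, 6]) = 77, then scale by 1 / sqrt(3).
print(77 / math.sqrt(3))  # ≈ 44.46
```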