Make sure that all attention works the same #5360
Merged

Changes from all commits (8 commits)
d073da1 (dirkgr): Adds a test to make sure that all attention works the same
b28a3fc (dirkgr): Merge branch 'main' into AttentionToAttention
23f0c84 (dirkgr): Autodetect scaling factor
6b24f80 (dirkgr): Refactors attention so that scaled dot product attention lives where …
f83ac8d (dirkgr): Formatting
1d1a76a (dirkgr): Merge branch 'AttentionToAttention' of https://github.com/allenai/all…
4ee5098 (dirkgr): Formatting
23c2ee9 (dirkgr): Changelog
allennlp/modules/matrix_attention/scaled_dot_product_matrix_attention.py (21 additions, 0 deletions)
```python
import math

import torch
from overrides import overrides

from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention


@MatrixAttention.register("scaled_dot_product")
class ScaledDotProductMatrixAttention(DotProductMatrixAttention):
    """
    Computes attention between every entry in matrix_1 with every entry in matrix_2 using a dot
    product. Scales the result by the size of the embeddings.

    Registered as a `MatrixAttention` with name "scaled_dot_product".
    """

    @overrides
    def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
        return super().forward(matrix_1, matrix_2) / math.sqrt(matrix_1.size(-1))
```
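As a quick usage sketch of the new module (the example tensors below are made up, not part of the PR): the scaled variant is just `DotProductMatrixAttention` divided by the square root of the embedding size.

```python
import math

import torch

from allennlp.modules.matrix_attention import (
    DotProductMatrixAttention,
    ScaledDotProductMatrixAttention,
)

# Made-up inputs: one batch, with 2 and 3 "tokens" of embedding size 3.
matrix_1 = torch.FloatTensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
matrix_2 = torch.FloatTensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])

scaled = ScaledDotProductMatrixAttention()(matrix_1, matrix_2)
unscaled = DotProductMatrixAttention()(matrix_1, matrix_2)

# One score per (row of matrix_1, row of matrix_2) pair, divided by sqrt(embedding size).
assert scaled.shape == (1, 2, 3)
assert torch.allclose(scaled, unscaled / math.sqrt(matrix_1.size(-1)))
```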
Changes in the module that defines `AttentionModule`:

```diff
@@ -6,7 +6,7 @@
 from allennlp.common import FromParams
 from allennlp.common.checks import ConfigurationError
-from allennlp.modules.attention import Attention
+from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
 from allennlp.modules.transformer.transformer_module import TransformerModule
 from allennlp.modules.transformer.util import apply_mask, FloatT, IntT, BoolT

@@ -51,7 +51,7 @@ class AttentionModule(TransformerModule, FromParams):
     scoring_func: `str` (default = `scaled_dot_product`)
         The name of the attention-calculating function to be used.
         Eg. `additive`, `linear`, etc. For a complete list, please check
-        :mod:`allennlp.modules.attention.attention`.
+        :mod:`allennlp.modules.matrix_attention.matrix_attention`.
     output_linear: `bool` (default = `False`)
         Whether to add an additional output linear layer at the end.
     dropout: `float` (default = `0.0`)

@@ -113,12 +113,7 @@ def __init__(
         self.output = torch.nn.Linear(self.all_head_size, hidden_size, bias=bias)

         self.scoring_func = scoring_func
-        if self.scoring_func in ["additive", "linear", "bilinear"]:
-            self.attn = Attention.by_name(self.scoring_func)(hidden_size, hidden_size)
-        elif self.scoring_func == "scaled_dot_product":
-            self.attn = Attention.by_name(self.scoring_func)(self.attention_head_size, False)
-        else:
-            self.attn = Attention.by_name(self.scoring_func)()
+        self.attn = MatrixAttention.by_name(self.scoring_func)()

         self.relative_attention_num_buckets = relative_attention_num_buckets

@@ -229,7 +224,7 @@ def _get_attention_probs(
         past_key_states: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        attention_scores = self.attn(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = self.attn(query_layer, key_layer)

         position_bias = self._position_bias(
             position_bias, seq_lengths, past_key_states, attention_scores

@@ -478,7 +473,7 @@ def __init__(
             attention_head_size=key_value_proj_dim,
             num_attention_heads=num_heads,
             output_linear=True,
-            scoring_func="scaled_dot_product",
+            scoring_func="dot_product",
             dropout=dropout,
             bias=False,
             normalize_weights=normalize,

@@ -487,8 +482,6 @@ def __init__(
             relative_attention_num_buckets=relative_attention_num_buckets,
         )

-        self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False)
-
     def forward(  # type: ignore
         self,
         hidden_states: torch.Tensor,
```

Inline review comments on the removed `self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False)` line:

"@AkshitaB, this is where the scaling factor is forced to `1`."

"This is for …"
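In the `__init__` hunk above, the scoring function is now obtained through the `MatrixAttention` registry instead of being special-cased per name. A minimal sketch of what that lookup does (only the argument-free case is shown; `"scaled_dot_product"` is the registration added by this PR):

```python
from allennlp.modules.matrix_attention import MatrixAttention

# Registrable lookup: returns the class registered under the given name.
attn_cls = MatrixAttention.by_name("scaled_dot_product")
attn = attn_cls()  # the dot-product variants take no constructor arguments
```

`by_name` is the standard `Registrable` lookup in AllenNLP, so a new scoring function only needs a `@MatrixAttention.register(...)` decorator to become usable here.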
New test file (21 additions):

```python
import pytest
import torch

from allennlp.modules import Attention
from allennlp.modules.attention import BilinearAttention, AdditiveAttention, LinearAttention


@pytest.mark.parametrize("attention_type", Attention.list_available())
def test_all_attention_works_the_same(attention_type: str):
    module_cls = Attention.by_name(attention_type)

    vector = torch.FloatTensor([[-7, -8, -9]])
    matrix = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])

    if module_cls in {BilinearAttention, AdditiveAttention, LinearAttention}:
        module = module_cls(vector.size(-1), matrix.size(-1))
    else:
        module = module_cls()

    output = module(vector, matrix)
    assert tuple(output.size()) == (1, 2)
```
New test file (21 additions):

```python
import pytest
import torch

from allennlp.modules import MatrixAttention
from allennlp.modules.matrix_attention import BilinearMatrixAttention, LinearMatrixAttention


@pytest.mark.parametrize("attention_type", MatrixAttention.list_available())
def test_all_attention_works_the_same(attention_type: str):
    module_cls = MatrixAttention.by_name(attention_type)

    matrix1 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]])
    matrix2 = torch.FloatTensor([[[1, 2, 3], [4, 5, 6]]])

    if module_cls in {BilinearMatrixAttention, LinearMatrixAttention}:
        module = module_cls(matrix1.size(-1), matrix2.size(-1))
    else:
        module = module_cls()

    output = module(matrix1, matrix2)
    assert tuple(output.size()) == (1, 3, 2)
```
tests/modules/matrix_attention/scaled_dot_product_matrix_attention_test.py (37 additions, 0 deletions)
```python
import math

import torch
from numpy.testing import assert_almost_equal
import numpy

from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.modules.matrix_attention import ScaledDotProductMatrixAttention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention


class TestScaledDotProductMatrixAttention(AllenNlpTestCase):
    def test_can_init_dot(self):
        legacy_attention = MatrixAttention.from_params(Params({"type": "scaled_dot_product"}))
        assert isinstance(legacy_attention, ScaledDotProductMatrixAttention)

    def test_dot_product_similarity(self):
        # Example use case: a batch of size 2, with a time component (e.g. sentences of
        # length 2), where each word is a vector of length 3. It is compared against
        # another input of the same shape.
        output = ScaledDotProductMatrixAttention()(
            torch.FloatTensor([[[0, 0, 0], [4, 5, 6]], [[-7, -8, -9], [10, 11, 12]]]),
            torch.FloatTensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]),
        )

        # For the first batch there is no correlation between the first words of the input
        # matrices but perfect correlation for the second word. For the second batch there
        # is negative correlation for the first words and positive correlation for the second.
        assert_almost_equal(
            output.numpy(),
            numpy.array([[[0, 0], [32, 77]], [[-194, -266], [266, 365]]]) / math.sqrt(3),
            decimal=2,
        )
```
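To make the expected values concrete, here is the arithmetic behind one row of the first batch (this only restates what the assertion above checks):

```python
import math

# Second word of batch 1, [4, 5, 6], against the two words of the other input:
#   [4, 5, 6] . [1, 2, 3] = 4 + 10 + 18 = 32
#   [4, 5, 6] . [4, 5, 6] = 16 + 25 + 36 = 77
# The embedding size is 3, so the scaled scores are 32 / sqrt(3) and 77 / sqrt(3).
print(32 / math.sqrt(3), 77 / math.sqrt(3))  # roughly 18.48 and 44.46
```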
Review conversation:

"Why do we change the default? The original transformer uses `scaled_dot_product`, right?"

"I asked you that on Slack. The original transformer uses `scaled_dot_product`, as per the paper, but in your implementation (which matches HF), the scaling factor is forced to 1, so it doesn't scale at all. I continue to be confused about this."

"Sorry, I missed that. From what I can see, the scaling factor is set to 1 for `T5Attention`, not for regular `SelfAttention`. I believe the original T5 does the same. By default, we set a scaling factor for regular `SelfAttention`."

"Right. This is also only for `T5Attention`. I believe I left it the same by default."

"So I guess it's just T5 being extra?"

"I re-read section 2.1 of the T5 paper, and it doesn't mention this at all 🤷🏼‍♂️."

"Yes, a whole lot of finicky little training details aren't mentioned in the 60+ page paper. I think we were following the HF implementation."
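For readers following this thread, the whole disagreement is about a single 1/sqrt(d_k) factor on the raw attention scores. A standalone sketch of the two variants being compared (plain tensor code, not the AllenNLP modules; the tensor shapes are made up):

```python
import math

import torch


def dot_product_scores(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # The T5 / HF behavior discussed above: raw Q.K^T, i.e. a scaling factor of 1.
    return torch.matmul(query, key.transpose(-1, -2))


def scaled_dot_product_scores(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # The "Attention Is All You Need" formulation: Q.K^T / sqrt(d_k).
    return dot_product_scores(query, key) / math.sqrt(query.size(-1))


query = torch.randn(2, 4, 8)  # (batch, query_len, d_k), made-up sizes
key = torch.randn(2, 5, 8)    # (batch, key_len, d_k)
assert scaled_dot_product_scores(query, key).shape == (2, 4, 5)
```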