Long T5 #179
base: master
Conversation
class LongformerSelfAttention(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, config, layer_id, bias=True, attention_dim_scale=True):
The T5 attention module is slightly different from conventional ones: it has no bias terms, and it does not scale the attention scores by the attention head dimension before the softmax. See this list for more details.
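For illustration, a minimal sketch of how the two new flags could be wired; the class and attribute names below are illustrative, not necessarily how the PR does it:

import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    def __init__(self, config, layer_id, bias=True, attention_dim_scale=True):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.attention_dim_scale = attention_dim_scale
        # T5 uses no bias in its q/k/v projections, so it would pass bias=False here.
        self.query = nn.Linear(config.hidden_size, config.hidden_size, bias=bias)
        self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=bias)
        self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=bias)

    def maybe_scale(self, q):
        # T5 also skips the 1/sqrt(head_dim) scaling before the softmax,
        # so attention_dim_scale=False turns it off.
        return q / (self.head_dim ** 0.5) if self.attention_dim_scale else q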
In the default option (bias=True, attention_dim_scale=True), this should just fall back to regular self-attention.
Please add your comment to the code.
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
# concat to attn_weights
-# (bsz, seq_len, num_heads, extra attention count + 2*window+1)
+# (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1)
changed annotation to be consistent with related annotations below
thanks
@@ -78,28 +89,38 @@ def __init__(self, config, layer_id):
        self.attention_dilation = config.attention_dilation[self.layer_id]
        self.attention_mode = config.attention_mode
        self.autoregressive = config.autoregressive

        if hasattr(config, "relative_attention_num_buckets") and layer_id == 0:
In T5, the position bias is shared across layers: the first layer computes the position bias and then passes it on to the remaining layers.
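A rough sketch of that sharing pattern; the per-layer interface here is simplified and illustrative, not the PR's actual signatures:

import torch.nn as nn

def run_layers(layers: nn.ModuleList, hidden_states):
    # Only the first layer (layer_id == 0 in the diff above) owns the
    # relative_attention_bias table; it computes position_bias once and every
    # later layer reuses the same tensor instead of recomputing it.
    position_bias = None
    for layer in layers:
        hidden_states, position_bias = layer(hidden_states, position_bias=position_bias)
    return hidden_states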
Good catch. Please write this comment in the code for more readability.
        if output_attentions:
            outputs = outputs + (attn_weights,)
        if self.has_relative_attention_bias:
            outputs = outputs + (position_bias,)
This is equivalent to the old output form when self.has_relative_attention_bias=False.
longformer/longformer.py (Outdated)

        return outputs


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
I was considering moving this to longformer_encoder_decoder, but that would lead to a circular import, so it has to be here.
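For readers of the diff, this function mirrors the bucketing used by T5 in HuggingFace transformers; a sketch of that logic follows (the version in this PR may differ in details):

import math
import torch

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Map signed relative positions to bucket ids: exact buckets for small
    # distances, logarithmically coarser buckets up to max_distance.
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets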
            layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)


class LongformerT5Config(T5Config):
You can see we are getting many highly similar config classes as we extend to other transformer models. If you like, we can simplify this by using a mixin: another mixin class would contain all the Longformer-specific settings, and LongformerT5Config would inherit from both the mixin class and T5Config.
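Something along these lines (a sketch only; the mixin name and the exact set of settings are placeholders, not decided in this PR):

from transformers import T5Config

class LongformerConfigMixin:
    # Holds the Longformer-specific settings once, so per-model config classes
    # only need to combine this mixin with the base model's config.
    def __init__(self, attention_window=None, attention_dilation=None,
                 autoregressive=False, attention_mode="sliding_chunks", **kwargs):
        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode

class LongformerT5Config(LongformerConfigMixin, T5Config):
    pass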
I don't have strong feelings about this. You decide (as long as we don't change the interface of the released code)
        )
        self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def forward(
An alternative I considered was to let this class inherit from LongformerSelfAttention, but I eventually decided not to: the interfaces of the two classes are quite different. What we have here, i.e. making LongformerSelfAttention a member of LongformerSelfAttentionForT5, is probably less confusing than the alternative.
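Roughly the shape of the composition (a sketch with assumed constructor arguments and forward signature; the real class in the PR may differ):

import torch.nn as nn
from longformer.longformer import LongformerSelfAttention

class LongformerSelfAttentionForT5Sketch(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        # Composition instead of inheritance: the T5 wrapper owns a
        # LongformerSelfAttention and adapts the T5 calling convention to it.
        self.longformer_self_attn = LongformerSelfAttention(
            config, layer_id=layer_id, bias=False, attention_dim_scale=False
        )
        self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def forward(self, hidden_states, mask=None, position_bias=None, **kwargs):
        # Run Longformer self-attention, then apply T5's bias-free output projection.
        outputs = self.longformer_self_attn(
            hidden_states, attention_mask=mask, position_bias=position_bias
        )
        return (self.output(outputs[0]),) + outputs[1:]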
        attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

        if position_bias is None and self.has_relative_attention_bias:
Since the sliding window has already put the attention scores in the form [q_i * k_(i-w), q_i * k_(i-w+1), ..., q_i * k_i, ..., q_i * k_(i+w)], the relative positions are simply an arange.
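In code, the idea is roughly the following (a sketch; it assumes a T5-style relative_attention_bias embedding of shape (num_buckets, num_heads) and the relative_position_bucket function from the diff):

import torch

def sliding_window_position_bias(relative_attention_bias, attention_window, num_buckets):
    # Every query sees the same 2*w+1 relative offsets, so an arange is enough.
    w = attention_window
    relative_position = torch.arange(-w, w + 1)  # [-w, ..., 0, ..., +w]
    buckets = relative_position_bucket(relative_position, bidirectional=True, num_buckets=num_buckets)
    bias = relative_attention_bias(buckets)      # (2*w+1, num_heads)
    # Broadcasts over (bsz, seq_len, num_heads, 2*w+1) attention scores.
    return bias.permute(1, 0)[None, None]        # (1, 1, num_heads, 2*w+1)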
please move this comment to the code.
nit: Maybe also move this block of code to a separate function
        perm_global_position_bias = attn_weights.new_zeros(
            bsz, max_num_extra_indices_per_batch, seq_len, self.num_heads
        )  # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads)
        if extra_attention_mask is not None:
Global position bias is a bit more complex. We first get the memory position from extra_attention_mask_nonzeros, then compute the query position using arange; their difference is the relative position. But this gives one "sparse" vector for each global token in the batch, so we later put it back into the shape (bsz, max_num_extra_indices_per_batch, ...) using the index information from selection_padding_mask_nonzeros.
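A reconstruction of that flow as code, following the description above rather than the exact diff (the variable names match the surrounding hunk; shapes and the exact scatter are my best guess):

import torch

def global_position_bias(relative_attention_bias, num_buckets,
                         extra_attention_mask_nonzeros, selection_padding_mask_nonzeros,
                         bsz, max_num_extra_indices_per_batch, seq_len, num_heads):
    # Memory (key) positions: one entry per global token across the batch.
    memory_position = extra_attention_mask_nonzeros[1]                      # (total_global_tokens,)
    query_position = torch.arange(seq_len, device=memory_position.device)   # (seq_len,)
    # Relative position of every query with respect to every global key.
    relative_position = memory_position[:, None] - query_position[None, :]  # (total_global_tokens, seq_len)
    buckets = relative_position_bucket(relative_position, bidirectional=True, num_buckets=num_buckets)
    sparse_bias = relative_attention_bias(buckets)                          # (total_global_tokens, seq_len, num_heads)
    # Scatter the per-token vectors back into the padded dense tensor from the hunk above.
    perm_bias = sparse_bias.new_zeros(bsz, max_num_extra_indices_per_batch, seq_len, num_heads)
    perm_bias[selection_padding_mask_nonzeros[0], selection_padding_mask_nonzeros[1]] = sparse_bias
    return perm_bias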
didn't review this part yet.
Looks great, thank you.
I left a few small comments. I didn't review the global attention part yet, will do later, maybe today.
tests/test_t5_short_sequence.py (Outdated)

            base_model_name_or_path="t5-small",
        )
        self._run_test(
            INPUT_TEXT="It begins with the Great Hungerer. It ends in utter darkeness.",
:D
    def test_outout(self):
        self._run_test(
            INPUT_TEXT="Hello world!",
            long_model_name_or_path="/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-4096",
It would be great if this test worked without the local model. One way to do so is to call create_long_model in the test to convert T5 to a long model, then test it. It will make the test slower but easier to run.
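Something like the following (a sketch only: the create_long_model signature here is assumed and may not match the conversion helper in this PR, and _run_test is the helper already defined in the test file):

import tempfile

def test_output_without_local_model(self):
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Convert the public t5-small checkpoint on the fly instead of
        # depending on a model stored on the NFS share.
        create_long_model(save_model_to=tmp_dir, base_model="t5-small",
                          attention_window=512, max_pos=4096)
        self._run_test(
            INPUT_TEXT="Hello world!",
            long_model_name_or_path=tmp_dir,
            base_model_name_or_path="t5-small",
        )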
@@ -5,14 +5,17 @@
import torch.nn.functional as F
from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
from longformer.sliding_chunks import (
It is fine that your dev env changed the file formatting. I know it doesn't change the code, but I would feel more comfortable if you ran a small test to make sure the new code produces the same output as the previous one for Longformer.
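For example, a quick check along these lines (a sketch; the checkpoint path is a placeholder for the downloaded Longformer model, following the README's usage, and the input length is chosen as a multiple of 2 * attention_window for the sliding_chunks mode):

import torch
from longformer.longformer import Longformer, LongformerConfig

config = LongformerConfig.from_pretrained("longformer-base-4096/")
config.attention_mode = "sliding_chunks"
model = Longformer.from_pretrained("longformer-base-4096/", config=config)
model.eval()

input_ids = torch.tensor([[0] + [100] * 1022 + [2]])  # any fixed 1024-token input
with torch.no_grad():
    hidden_states = model(input_ids)[0]

# On master: torch.save(hidden_states, "reference.pt")
# On this branch: compare against the saved reference.
assert torch.allclose(hidden_states, torch.load("reference.pt"), atol=1e-5)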
    # in T5 attention_probs_dropout_prob is dropout_rate
    config.attention_probs_dropout_prob = config.dropout_rate
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers
When increasing the model length, we probably want to increase the number of relative position buckets as well (config.relative_attention_num_buckets).
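For example (the concrete values below are illustrative, not something this PR decides):

# t5-small defaults were tuned for ~512-token inputs:
#   relative_attention_num_buckets = 32, max_distance = 128
# For a 4096-token model, distant positions would otherwise collapse into the
# last few logarithmic buckets, so it may be worth growing the bucket count
# (and max_distance, where configurable) alongside the sequence length.
config.relative_attention_num_buckets = 64  # illustrative; default is 32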
Based on @AkshitaB's work (#149), this PR extends Longformer to T5. It also adds a test to check that the Longformer T5 produces the same output as the standard T5 on short input texts, as suggested by @ibeltagy in this comment.

A quick note about code style: I'm not sure whether this repo has settled on a formatter; I didn't find a dev-requirements.txt, so I continued to use the black formatter with my default settings. It automatically re-formats the file whenever I save. You may notice changes like ' -> ", or long lines being broken into multiple lines. I hope it doesn't bother you too much.