Long T5 #179

Open · wants to merge 10 commits into base: master
Conversation

HaokunLiu

Based on @AkshitaB's work (#149), this PR extends Longformer to T5. It also adds a test checking that the Longformer T5 produces the same output as the standard T5 on short input texts, as suggested by @ibeltagy in this comment

A quick note about code style: I'm not sure whether this repo has settled on a formatter; I didn't find a dev-requirements.txt, so I kept using the black formatter with my default settings. It automatically reformats the file whenever I save, so you may notice changes like ' -> ", or long lines being broken into multiple lines. I hope it doesn't bother you too much.



class LongformerSelfAttention(nn.Module):
-     def __init__(self, config, layer_id):
+     def __init__(self, config, layer_id, bias=True, attention_dim_scale=True):
Author

The T5 attention module is slightly different from conventional ones: it has no bias terms, and it does not scale the attention scores by the attention head dimension before the softmax. See this list for more details.

Author

With the default options bias=True and attention_dim_scale=True, this just falls back to regular self-attention.
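For reference, a minimal sketch of where the two flags would take effect (illustrative names and shapes, not the PR's code):

import math
import torch
import torch.nn.functional as F
from torch import nn

class SketchSelfAttention(nn.Module):
    # Illustrative only: shows what bias and attention_dim_scale toggle.
    def __init__(self, embed_dim, num_heads, bias=True, attention_dim_scale=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attention_dim_scale = attention_dim_scale
        # T5 projections carry no bias; BERT/RoBERTa-style attention keeps it.
        self.query = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(self, hidden_states):
        bsz, seq_len, _ = hidden_states.shape
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        q = self.query(hidden_states).view(shape).transpose(1, 2)
        k = self.key(hidden_states).view(shape).transpose(1, 2)
        v = self.value(hidden_states).view(shape).transpose(1, 2)
        if self.attention_dim_scale:
            # Conventional attention scales by sqrt(head_dim); T5 skips this.
            q = q / math.sqrt(self.head_dim)
        attn = F.softmax(q @ k.transpose(-1, -2), dim=-1)
        return (attn @ v).transpose(1, 2).reshape(bsz, seq_len, -1)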

Collaborator

Please add your comment to the code.

selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
# concat to attn_weights
- # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
+ # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1)
Author

changed annotation to be consistent with related annotations below

Collaborator

thanks

@@ -78,28 +89,38 @@ def __init__(self, config, layer_id):
self.attention_dilation = config.attention_dilation[self.layer_id]
self.attention_mode = config.attention_mode
self.autoregressive = config.autoregressive

if hasattr(config, "relative_attention_num_buckets") and layer_id == 0:
Author

In T5, the position bias is shared across layers: the first layer computes it and then passes it on to the remaining layers.
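A hedged sketch of this sharing pattern (the layer below is a stand-in, not the PR's class; a real layer would bucket the relative positions before the embedding lookup):

import torch
from torch import nn

class SketchLayer(nn.Module):
    # Only layer 0 owns the relative-attention-bias table.
    def __init__(self, num_buckets, num_heads, layer_id):
        super().__init__()
        self.has_relative_attention_bias = layer_id == 0
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    def forward(self, hidden_states, position_bias=None):
        if position_bias is None and self.has_relative_attention_bias:
            # Layer 0 computes the bias once; zeros stand in for bucketed
            # relative positions here.
            seq_len = hidden_states.size(1)
            buckets = torch.zeros(seq_len, seq_len, dtype=torch.long)
            position_bias = self.relative_attention_bias(buckets).permute(2, 0, 1)
        # ...the attention scores would have position_bias added here...
        return hidden_states, position_bias

layers = nn.ModuleList(SketchLayer(32, 8, i) for i in range(3))
hidden, bias = torch.randn(2, 5, 16), None
for layer in layers:
    hidden, bias = layer(hidden, position_bias=bias)  # layers 1+ reuse layer 0's bias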

Collaborator

Good catch. Please write this comment in the code for more readability.

if output_attentions:
    outputs = outputs + (attn_weights,)
if self.has_relative_attention_bias:
    outputs = outputs + (position_bias,)
Author

This is equivalent to the old output form when self.has_relative_attention_bias=False.

return outputs


def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
Author

I considered moving this to longformer_encoder_decoder, but that would create a circular import, so it has to stay here.

layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)


class LongformerT5Config(T5Config):
Author

As you can see, we are accumulating many highly similar config classes as we extend to other transformer models. If you like, we can simplify this with a mixin: a separate mixin class would hold all the Longformer-specific settings, and LongformerT5Config would inherit from both the mixin and T5Config.
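One possible shape of that mixin (a sketch only; attribute names are taken from the diff above, the defaults and cooperative super().__init__ chain are illustrative assumptions):

from transformers import T5Config

class LongformerConfigMixin:
    # Hypothetical mixin holding the Longformer-specific settings.
    def __init__(self, attention_window=None, attention_dilation=None,
                 autoregressive=False, attention_mode="sliding_chunks", **kwargs):
        super().__init__(**kwargs)  # forwards the remaining kwargs to T5Config
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode

class LongformerT5Config(LongformerConfigMixin, T5Config):
    pass

config = LongformerT5Config(attention_window=[512] * 6, attention_dilation=[1] * 6)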

Collaborator

I don't have strong feelings about this. You decide (as long as we don't change the interface of the released code)

)
self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

def forward(
Author

An alternative I considered was to let this class inherit from LongformerSelfAttention, but I eventually decided against it: the interfaces of the two classes are quite different. What we have here, i.e., making LongformerSelfAttention a member of LongformerSelfAttentionForT5, is probably less confusing than the alternative.
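A rough sketch of the composition pattern being described (the class name, the simplified forward signature, and the argument translation are illustrative, not the PR's exact code):

from torch import nn
from longformer.longformer import LongformerSelfAttention

class SketchSelfAttentionForT5(nn.Module):
    # Hold a LongformerSelfAttention as a member instead of inheriting from it.
    def __init__(self, config, layer_id):
        super().__init__()
        self.embed_dim = config.d_model
        # T5-style attention: no bias, no sqrt(head_dim) scaling.
        self.longformer_self_attn = LongformerSelfAttention(
            config, layer_id=layer_id, bias=False, attention_dim_scale=False
        )
        self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Translate T5's calling convention, run Longformer attention, then
        # apply the bias-free output projection T5 expects.
        outputs = self.longformer_self_attn(hidden_states, attention_mask=attention_mask)
        return (self.output(outputs[0]),) + outputs[1:]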

attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

if position_bias is None and self.has_relative_attention_bias:
Author

Since the sliding window has already put the attention scores in the form [q_(i) * k_(i-w), q_(i) * k_(i-w+1), ..., q_(i) * k_(i), ..., q_(i) * k_(i+w)], the relative positions are simply an arange.
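In code, the idea could look like the following sketch, which reuses the relative_position_bucket function shown in this diff (the window size and the broadcast comments are illustrative):

import torch

w = 4                                        # one-sided window size (illustrative)
relative_position = torch.arange(-w, w + 1)  # shape: (2*w + 1,), same for every query i
buckets = relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
)
# buckets would then index the shared relative_attention_bias embedding, giving a
# (2*w + 1, num_heads) bias that is broadcast over (bsz, seq_len, num_heads, 2*w + 1)
# and added to attn_weights before the softmax.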

Collaborator

please move this comment to the code.

Collaborator

nit: Maybe also move this block of code to a separate function

perm_global_position_bias = attn_weights.new_zeros(
    bsz, max_num_extra_indices_per_batch, seq_len, self.num_heads
)  # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads)
if extra_attention_mask is not None:
Author

Global position bias is a bit more complex. We first get the memory positions from extra_attention_mask_nonzeros, then compute the query positions using arange; their difference is the relative position. But this is a "sparse" layout, one vector per global token in the batch, so we later put it back into the shape (bsz, max_num_extra_indices_per_batch, ...) using the index information from selection_padding_mask_nonzeros.
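A hedged, self-contained sketch of that bookkeeping (tiny hand-made shapes, and a zero tensor standing in for the bucketed bias lookup; not the PR's exact code):

import torch

bsz, seq_len, num_heads, max_num_extra = 2, 8, 4, 2
extra_attention_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
extra_attention_mask[0, 0] = True            # one global token in example 0
extra_attention_mask[1, :2] = True           # two global tokens in example 1
extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)

# Relative position of every query with respect to each global (memory) token.
memory_position = extra_attention_mask_nonzeros[1]               # (num_global_tokens,)
query_position = torch.arange(seq_len)                           # (seq_len,)
relative_position = memory_position[:, None] - query_position[None, :]

# The real code would bucket relative_position and look up relative_attention_bias;
# zeros stand in for that lookup here.
sparse_bias = torch.zeros(memory_position.size(0), seq_len, num_heads)

# Scatter the per-global-token vectors back into a padded per-example layout,
# mirroring selection_padding_mask_nonzeros in the diff.
num_extra = extra_attention_mask.sum(dim=1)
selection_padding_mask = torch.arange(max_num_extra)[None, :] < num_extra[:, None]
selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
perm_global_position_bias = torch.zeros(bsz, max_num_extra, seq_len, num_heads)
perm_global_position_bias[selection_padding_mask_nonzeros] = sparse_bias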

Collaborator

didn't review this part yet.

Collaborator
@ibeltagy left a comment

Looks great, thank you.
I left a few small comments. I didn't review the global attention part yet; I will do that later, maybe today.


    base_model_name_or_path="t5-small",
)
self._run_test(
    INPUT_TEXT="It begins with the Great Hungerer. It ends in utter darkeness.",
Collaborator

:D

def test_outout(self):
    self._run_test(
        INPUT_TEXT="Hello world!",
        long_model_name_or_path="/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-4096",
Collaborator

It would be great if this test worked without the local model. One way to do so is to call create_long_model in the test to convert T5 to long, then test it. It will make the test slower but easier to run.
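A sketch of what such a test could look like. The import paths, the LongformerT5ForConditionalGeneration class name, and the create_long_model signature below are assumptions modeled on the repo's existing BART conversion script, not confirmed against this PR:

import tempfile
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

from scripts.convert_t5_to_longformerencoderdecoder import create_long_model  # assumed path
from longformer.longformer_encoder_decoder_t5 import LongformerT5ForConditionalGeneration  # assumed path


def test_output_without_local_checkpoint():
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Convert t5-small on the fly instead of loading a pre-converted
        # checkpoint from a local NFS path.
        create_long_model(
            save_model_to=tmp_dir,
            base_model="t5-small",
            tokenizer_name_or_path="t5-small",
            attention_window=512,
            max_pos=4096,
        )
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        short_model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
        long_model = LongformerT5ForConditionalGeneration.from_pretrained(tmp_dir).eval()
        batch = tokenizer("Hello world!", return_tensors="pt")
        with torch.no_grad():
            short_ids = short_model.generate(**batch)
            long_ids = long_model.generate(**batch)
        # On short inputs the long model should reproduce the original T5 output.
        assert torch.equal(short_ids, long_ids)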


@@ -5,14 +5,17 @@
import torch.nn.functional as F
from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
- from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
- from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
+ from longformer.sliding_chunks import (
Collaborator

It is fine that your dev environment reformatted the files. I know it doesn't change the code, but I would feel more comfortable if you ran a small test to make sure the new code produces the same output as the previous one for Longformer.
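One way such a check could look: run a fixed, seeded input through the reformatted LongformerSelfAttention and compare against a reference tensor saved with the pre-reformat code. This is only a sketch; reference.pt is a placeholder file you would generate yourself (with the same seed) before applying this PR:

import torch
from longformer.longformer import LongformerConfig, LongformerSelfAttention

torch.manual_seed(0)
config = LongformerConfig(attention_window=[256] * 12, attention_dilation=[1] * 12,
                          attention_mode="sliding_chunks", autoregressive=False)
layer = LongformerSelfAttention(config, layer_id=0).eval()
hidden_states = torch.randn(1, 1024, config.hidden_size)
attention_mask = torch.zeros(1, 1, 1, 1024)  # 0 = local attention in this repo
with torch.no_grad():
    new_output = layer(hidden_states, attention_mask=attention_mask)[0]
reference = torch.load("reference.pt")       # saved by the pre-reformat code
assert torch.allclose(new_output, reference, atol=1e-6)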

# in T5 attention_probs_dropout_prob is dropout_rate
config.attention_probs_dropout_prob = config.dropout_rate
config.attention_window = [attention_window] * config.num_hidden_layers
config.attention_dilation = [1] * config.num_hidden_layers
Contributor

When increasing the model length, we probably want to increase the number of relative position buckets (config.relative_attention_num_buckets) as well.
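For example, the conversion could bump the bucket count alongside the settings above; the scaling below is purely illustrative, not something this PR prescribes (the T5 default is 32 buckets):

# Illustrative: grow the bucket count together with the maximum length.
config.relative_attention_num_buckets = 2 * config.relative_attention_num_buckets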
