Make sure that all attention works the same #5360
Conversation
LGTM. Just a question about the default scoring function.
@@ -478,7 +473,7 @@ def __init__(
     attention_head_size=key_value_proj_dim,
     num_attention_heads=num_heads,
     output_linear=True,
-    scoring_func="scaled_dot_product",
+    scoring_func="dot_product",
Why do we change the default? The original transformer uses scaled_dot_product, right?
I asked you that on Slack. The original transformer uses scaled_dot_product, as per the paper, but in your implementation (which matches HF), the scaling factor is forced to 1, so it doesn't scale at all. I continue to be confused about this.
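For context, here is a minimal, self-contained sketch (not from the PR) of the two scoring functions under discussion. The function names, tensor shapes, and the `scaling_factor` argument are illustrative assumptions; the point is just that forcing the scaling factor to 1 makes `scaled_dot_product` behave like plain `dot_product`.

```python
import math
from typing import Optional

import torch


def dot_product_score(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # Plain dot product between each query/key pair: q . k
    return (query * key).sum(dim=-1)


def scaled_dot_product_score(
    query: torch.Tensor, key: torch.Tensor, scaling_factor: Optional[int] = None
) -> torch.Tensor:
    # Scaled dot product: (q . k) / sqrt(scaling_factor).  The scaling factor
    # defaults to the key dimension, as in "Attention Is All You Need".
    if scaling_factor is None:
        scaling_factor = key.size(-1)
    return (query * key).sum(dim=-1) / math.sqrt(scaling_factor)


q = torch.randn(2, 64)
k = torch.randn(2, 64)

# With scaling_factor=1 the "scaled" variant degenerates to the plain dot
# product, which is why forcing the scaling factor to 1 makes the two
# scoring functions equivalent.
assert torch.allclose(
    scaled_dot_product_score(q, k, scaling_factor=1),
    dot_product_score(q, k),
)
```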
Sorry, I missed that. From what I can see, the scaling factor is set to 1 for T5Attention, not for regular SelfAttention. I believe the original T5 does the same. By default, we set a scaling factor for regular SelfAttention.
Right. This is also only for T5Attention. I believe I left it the same by default.
So I guess it's just T5 being extra?
I re-read section 2.1 of the T5 paper, and it doesn't mention this at all 🤷🏼‍♂️.
Yes, a whole lot of finicky little training details aren't mentioned in the 60+ page paper. I think we were following the HF implementation.
@@ -487,8 +482,6 @@ def __init__(
     relative_attention_num_buckets=relative_attention_num_buckets,
 )

 self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False)
@AkshitaB, this is where the scaling factor is forced to 1.
This is for T5Attention.
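For anyone following along, here is a sketch of what that line does, assuming the standard AllenNLP `Registrable` mechanics and that the attention class named by `self.scoring_func` is registered under that name and accepts these constructor arguments, as the diff suggests:

```python
from allennlp.modules.attention import Attention

# by_name() looks up the Attention subclass registered under the given name;
# calling the returned class constructs it with the keyword arguments from
# the diff.  With scaling_factor=1, the scaled variant does no real scaling.
attention_cls = Attention.by_name("scaled_dot_product")  # e.g. the scoring_func discussed above
attn = attention_cls(scaling_factor=1, normalize=False)
```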
Addresses #5345.

- Added ScaledDotProductMatrixAttention, and converted the transformer toolkit to use it
- Made sure that Attention and MatrixAttention implementations are interchangeable
- Changed ScaledDotProductAttention to match the other Attention classes