This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Make sure that all attention works the same #5360

Merged
dirkgr merged 8 commits into main from AttentionToAttention on Aug 17, 2021

Conversation

dirkgr
Member

@dirkgr dirkgr commented Aug 16, 2021

Addresses #5345.

  • Added ScaledDotProductMatrixAttention and converted the transformer toolkit to use it (a sketch of the computation follows after this list)
  • Added tests to ensure that all Attention and MatrixAttention implementations are interchangeable
  • Fixed the signature of ScaledDotProductAttention to match the other Attention classes.
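
For context, here is a minimal sketch of what scaled dot-product matrix attention computes, and how the single-vector Attention case is the same computation with a one-row query matrix. This is an illustrative sketch under assumed names and shapes, not the AllenNLP classes added in this PR.

```python
# Illustrative sketch only -- not the AllenNLP implementation added in this PR.
import math

import torch


def scaled_dot_product_matrix_attention(
    matrix_1: torch.Tensor,  # (batch, rows_1, dim)
    matrix_2: torch.Tensor,  # (batch, rows_2, dim)
) -> torch.Tensor:           # (batch, rows_1, rows_2), unnormalized similarities
    # Every row of matrix_1 scores every row of matrix_2, scaled by 1/sqrt(dim).
    return matrix_1.bmm(matrix_2.transpose(1, 2)) / math.sqrt(matrix_1.size(-1))


# The interchangeability idea behind the new tests, in spirit: vector-vs-matrix
# attention is just matrix attention with a single-row query matrix.
vector = torch.randn(2, 7)     # (batch, dim)
matrix = torch.randn(2, 5, 7)  # (batch, rows, dim)
scores = scaled_dot_product_matrix_attention(vector.unsqueeze(1), matrix).squeeze(1)
assert scores.shape == (2, 5)  # one score per row of `matrix`
```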

@dirkgr dirkgr linked an issue Aug 16, 2021 that may be closed by this pull request
@dirkgr dirkgr marked this pull request as ready for review August 17, 2021 18:46
@dirkgr dirkgr requested a review from AkshitaB August 17, 2021 18:46
Contributor

@AkshitaB AkshitaB left a comment

LGTM. Just a question about the default scoring function.

@@ -478,7 +473,7 @@ def __init__(
    attention_head_size=key_value_proj_dim,
    num_attention_heads=num_heads,
    output_linear=True,
-   scoring_func="scaled_dot_product",
+   scoring_func="dot_product",
Contributor

Why do we change the default? The original transformer uses scaled_dot_product, right?

Member Author

I asked you that on Slack. The original transformer uses scaled_dot_product, as per the paper, but in your implementation (which matches HF), the scaling factor is forced to 1, so it doesn't scale at all. I continue to be confused about this.
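
To make that concrete, here is a hedged sketch (not AllenNLP's actual classes) of dot-product scoring with and without scaling. Whether the divisor is the scaling factor itself or its square root, a factor of 1 leaves the scores untouched, so the "scaled" variant behaves like the plain one here.

```python
# Hedged sketch of the effect under discussion, not AllenNLP's actual code.
import math

import torch


def dot_product_scores(vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    # vector: (batch, dim), matrix: (batch, rows, dim) -> (batch, rows)
    return matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)


def scaled_dot_product_scores(
    vector: torch.Tensor, matrix: torch.Tensor, scaling_factor: float
) -> torch.Tensor:
    return dot_product_scores(vector, matrix) / math.sqrt(scaling_factor)


v, m = torch.randn(2, 16), torch.randn(2, 4, 16)
# With scaling_factor=1, "scaled" dot product is numerically identical to the
# plain dot product -- which is why this code path never actually scales.
assert torch.allclose(scaled_dot_product_scores(v, m, scaling_factor=1), dot_product_scores(v, m))
```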

Contributor

Sorry, I missed that. From what I can see, the scaling factor is set to 1 for T5Attention, not for regular SelfAttention. I believe the original T5 does the same. By default, we set a scaling factor for regular SelfAttention.

Member Author

Right. This is also only for T5Attention. I believe I left it the same by default.

Member Author

So I guess it's just T5 being extra?

Member Author

I re-read section 2.1 of the T5 paper, and it doesn't mention this at all 🤷🏼‍♂️.

Contributor

Yes, a whole lot of finicky little training details aren't mentioned in the 60+ page paper. I think we were following the HF implementation.

@@ -487,8 +482,6 @@ def __init__(
relative_attention_num_buckets=relative_attention_num_buckets,
)

self.attn = Attention.by_name(self.scoring_func)(scaling_factor=1, normalize=False)
Member Author

@AkshitaB, this is where the scaling factor is forced to 1.

Contributor

This is for T5Attention.

@dirkgr dirkgr merged commit d45a2da into main Aug 17, 2021
@dirkgr dirkgr deleted the AttentionToAttention branch August 17, 2021 23:51

Successfully merging this pull request may close these issues.

Scaled Dot Product Attention matmul error