
Implement MultiHeadAttention Layer #7875

Merged · 25 commits merged into tensorflow:master from mha-impl on Aug 7, 2023

Conversation

pforderique (Contributor) commented:

Implements the MultiHeadAttention layer from Keras attention layers.
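
For reference, here is a minimal sketch of the scaled dot-product attention that each head of a multi-head attention layer computes, written with plain TensorFlow.js ops. It is illustrative only and is not the code added in this PR:

```ts
import * as tf from '@tensorflow/tfjs';

// Illustrative sketch only (not the layer added in this PR): the scaled
// dot-product attention computed per head of MultiHeadAttention.
// Shapes: query [batch, Tq, dim], key and value [batch, Tv, dim].
function scaledDotProductAttention(
    query: tf.Tensor3D, key: tf.Tensor3D, value: tf.Tensor3D): tf.Tensor3D {
  return tf.tidy(() => {
    const dim = query.shape[2];
    // Attention scores [batch, Tq, Tv], scaled by sqrt of the key dimension.
    const scores = tf.matMul(query, key, false, true)
        .div(tf.sqrt(tf.scalar(dim)));
    // Normalize along the key axis to get attention weights.
    const weights = tf.softmax(scores, -1);
    // Weighted sum of values: [batch, Tq, dim].
    return tf.matMul(weights, value) as tf.Tensor3D;
  });
}
```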

NOTE:

  • This implementation does not support RaggedTensors yet.
  • The Softmax layer was changed to support masking (see the sketch after this list for a common masking pattern). Let me know if having a separate class for this is better; regardless, the current change should be non-breaking.
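
For illustration, one common way to make softmax mask-aware is to add a large negative penalty to masked positions before normalizing, so they receive near-zero attention weight. The sketch below uses plain TFJS ops and is not necessarily the exact change made to the Softmax layer in this PR:

```ts
import * as tf from '@tensorflow/tfjs';

// Illustrative sketch only: masked softmax via a large negative additive
// penalty on positions where the mask is 0, so they get ~zero weight.
// Not necessarily how the modified Softmax layer implements masking.
function maskedSoftmax(logits: tf.Tensor, mask: tf.Tensor): tf.Tensor {
  return tf.tidy(() => {
    const penalty = tf.sub(tf.scalar(1), tf.cast(mask, 'float32')).mul(-1e9);
    return tf.softmax(tf.add(logits, penalty), -1);
  });
}
```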

Depends on #7860.

@mattsoulanille (Member) left a comment:

LGTM

@mattsoulanille (Member) commented:

@Linchenn Please take a look when you get a chance. Thanks!

pforderique enabled auto-merge (squash) on August 7, 2023, 17:22
@Linchenn (Collaborator) left a comment:

LGTM!

pforderique merged commit 0fd462d into tensorflow:master on Aug 7, 2023
pforderique deleted the mha-impl branch on August 7, 2023, 21:50
pforderique added a commit that referenced this pull request on Aug 7, 2023:
* Implement position embedding

* Strip debug ops in jax conversion tests (#7889)

INTERNAL
This fixes an internal issue with jax tests. See cl/550054296.

* Update weights loading (#7872)

* Update weights loading

* fix tests

* remove

* fix

* fix comments

* fix lint

* Load python rules in tfjs-converter converters dir (#7892)

* Implement MultiHeadAttention Layer (#7875)

* Add spec for multi-head attention

* Add CachedMultiHeadAttention cache

* Fix typos

* Lint

* Add Transformer Decoder spec

* lint

* Add Einsum spec

* lint

* Remove unused type declaration

* Move helper functions outside EinsumDense class

* Implement Einsum Dense

* Address comments

* Implement MHA Layer

* Add masked softmax support

* Fix typo

* Check for undef and null

* Make buildFromSignature public

* Wrap softmax call in tf.tidy

* Implement position embedding

---------

Co-authored-by: Matthew Soulanille <msoulanille@google.com>
Co-authored-by: fengwuyao <131706622+fengwuyao@users.noreply.github.com>