
Add inputs_k and inputs_v args to attention layer #3379

Closed
wants to merge 1 commit into main from chiamp:attention

Conversation

@chiamp chiamp (Collaborator) commented Sep 28, 2023

Currently, the MultiHeadDotProductAttention layer's call method signature is MultiHeadDotProductAttention.__call__(inputs_q, inputs_kv, mask=None, deterministic=None). As discussed in #1737, there are cases where passing separate tensors for the key and value is desired, which isn't possible with the current API. This PR adds two more arguments, inputs_k and inputs_v, to the call method, making the signature MultiHeadDotProductAttention.__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None). Note that the inputs_kv, mask and deterministic args are now keyword-only arguments.

  • if inputs_k and inputs_v are None, they will both copy the value of inputs_q (i.e. self-attention)
  • if inputs_v is None, it will copy the value of inputs_k (the same behavior as the previous API, i.e. module.apply(inputs_q=query, inputs_k=key_value, ...) is equivalent to module.apply(inputs_q=query, inputs_kv=key_value, ...))
  • if inputs_kv is not None, both inputs_k and inputs_v will copy the value of inputs_kv
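
A minimal sketch of these three cases (hypothetical shapes and layer configuration, not taken from the PR itself):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical shapes: batch=2, query length=4, key/value length=6, features=8.
query = jnp.ones((2, 4, 8))
key = jnp.ones((2, 6, 8))
value = jnp.ones((2, 6, 8))

layer = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=8)
variables = layer.init(jax.random.PRNGKey(0), query)

# Case 1: inputs_k and inputs_v omitted -> both copy inputs_q (self-attention).
out_self = layer.apply(variables, query)

# Case 2: inputs_v omitted -> it copies inputs_k (same result as the old inputs_kv API).
out_shared = layer.apply(variables, query, key)

# Case 3: separate key and value tensors, the new capability added by this PR.
out_separate = layer.apply(variables, query, key, value)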

Users can still use inputs_kv, but a DeprecationWarning will be raised, and inputs_kv will be removed in the future.
Since self-attention can be done with this new API, the SelfAttention layer will also raise a DeprecationWarning and will be removed in the future.
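
For example, self-attention code can be ported by swapping the layer class, as in this hedged sketch (hypothetical shapes, not taken from the PR itself):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 4, 8))  # hypothetical self-attention input
rng = jax.random.PRNGKey(0)

# Deprecated: the dedicated SelfAttention layer (now emits a DeprecationWarning).
old_out, _ = nn.SelfAttention(num_heads=2).init_with_output(rng, x)

# Replacement: MultiHeadDotProductAttention called with only inputs_q;
# inputs_k and inputs_v default to inputs_q, so this is self-attention.
new_out, _ = nn.MultiHeadDotProductAttention(num_heads=2).init_with_output(rng, x)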

Check out #3389 to see examples of how to port your code over to the new API.

@chiamp chiamp self-assigned this Sep 28, 2023
@chiamp chiamp changed the title from "split inputs_kv arg in attention layer" to "Add inputs_k and inputs_v args to attention layer" on Sep 28, 2023
@codecov-commenter codecov-commenter commented Sep 28, 2023

Codecov Report

Merging #3379 (1d41190) into main (f20aed4) will increase coverage by 0.02%.
Report is 2 commits behind head on main.
The diff coverage is 90.90%.

@@            Coverage Diff             @@
##             main    #3379      +/-   ##
==========================================
+ Coverage   83.60%   83.62%   +0.02%     
==========================================
  Files          56       56              
  Lines        6746     6767      +21     
==========================================
+ Hits         5640     5659      +19     
- Misses       1106     1108       +2     
Files                      Coverage           Δ
flax/linen/attention.py    94.19% <90.90%>    (-0.59%) ⬇️

... and 1 file with indirect coverage changes

@chiamp chiamp force-pushed the attention branch 8 times, most recently from 2db0753 to dc02493, on October 4, 2023 22:13
flax/linen/attention.py (outdated review thread, resolved)
@cgarciae cgarciae (Collaborator) commented Oct 5, 2023

Left a comment. Otherwise, looks good!

copybara-service bot pushed a commit that referenced this pull request Oct 11, 2023
--
f6a222c by Marcus Chiam <marcuschiam@google.com>:

split inputs_kv arg in attention layer

COPYBARA_INTEGRATE_REVIEW=#3379 from chiamp:attention f6a222c
PiperOrigin-RevId: 572671273
@chiamp chiamp (Collaborator, Author) commented Oct 11, 2023

Closing after this commit landed.

@chiamp chiamp closed this Oct 11, 2023
@chiamp chiamp deleted the attention branch October 27, 2023 21:11
8bitmp3 pushed a commit to 8bitmp3/flax that referenced this pull request Nov 16, 2023