Add GroupedQueryAttention layer #18488

Merged
merged 38 commits into from
Oct 22, 2023
e6bc383
add: gqa barebone
awsaf49 Sep 25, 2023
04b6fe7
update: avoid abbr. in signature
awsaf49 Sep 26, 2023
85717ea
update: use fully-spelled snake_case
awsaf49 Sep 26, 2023
18e319c
add: initializers, regularizers, constraint
awsaf49 Oct 14, 2023
2bc6e46
update: mask to attention_mask
awsaf49 Oct 14, 2023
004dea4
add: `dropout`
awsaf49 Oct 15, 2023
acd8a7e
add: `softmax` from `ops`
awsaf49 Oct 15, 2023
b0d45db
update: use softmax for masking
awsaf49 Oct 15, 2023
cbfc9d2
add: import in `__init__`
awsaf49 Oct 15, 2023
660429a
update: filename
awsaf49 Oct 15, 2023
df13321
add: `compute_output_shape`
awsaf49 Oct 15, 2023
1f93104
add: `query-key-value` & `causal` mask
awsaf49 Oct 15, 2023
7c6f063
update: use `EinsumDense` and `einsum`
awsaf49 Oct 15, 2023
7d4b8c9
update: remove `Dense` import
awsaf49 Oct 15, 2023
c6b73cc
fix: `__init__` in layer for `isort`
awsaf49 Oct 15, 2023
d34de7b
update: docstring
awsaf49 Oct 15, 2023
321914d
add: simple test
awsaf49 Oct 15, 2023
b361be5
update: `support_masking` False in test
awsaf49 Oct 15, 2023
c0010e3
fix: error due to query & key-value seq_len mismatch
awsaf49 Oct 15, 2023
5ea5379
Revert "update: `support_masking` False in test"
awsaf49 Oct 15, 2023
6cbfd39
update: `use_bias` = True as default
awsaf49 Oct 15, 2023
af9051f
add: `support_masking`
awsaf49 Oct 15, 2023
8b1489d
add: more tests
awsaf49 Oct 15, 2023
334c776
update: code format with `black`
awsaf49 Oct 15, 2023
81ca49c
remove: high dim attention test
awsaf49 Oct 15, 2023
710a661
fix: remove undefined arg `num_head`
awsaf49 Oct 15, 2023
2a92866
add: shape mismatch test
awsaf49 Oct 15, 2023
30b88e8
add: initializer test
awsaf49 Oct 15, 2023
07ebe3b
update: `softmax` for mask propagation
awsaf49 Oct 15, 2023
5514345
update: code format
awsaf49 Oct 15, 2023
e70dbb1
add: mask propagation test
awsaf49 Oct 16, 2023
227325a
add: masking test
awsaf49 Oct 16, 2023
35f7e80
add: correctness test
awsaf49 Oct 16, 2023
d018a7e
update: output shape test for mqa, gqa & mha
awsaf49 Oct 16, 2023
a3b89dc
fix: code format for `isort`
awsaf49 Oct 16, 2023
51219fc
add: divisible error check
awsaf49 Oct 19, 2023
75b309c
add: shape of `attention_scores`
awsaf49 Oct 19, 2023
1535b7e
update: different letters for query and key-value heads
awsaf49 Oct 19, 2023
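
The commits above mention a divisibility check and output shape tests covering MQA, GQA and MHA. As a rough illustration of how these relate, the sketch below maps each query head to the key-value head it would share. It is not code from this PR: the helper names `key_value_head_index` and `head_mapping` are hypothetical, and it assumes key-value heads are shared by consecutive query heads, which is the usual grouped-query convention.

```python
def key_value_head_index(query_head, num_query_heads, num_key_value_heads):
    """Return the key-value head that a given query head attends with.

    Hypothetical helper for illustration only; assumes consecutive query
    heads form one group that shares a single key-value head.
    """
    if num_query_heads % num_key_value_heads != 0:
        # Each key-value head must serve the same number of query heads.
        raise ValueError(
            "num_query_heads must be divisible by num_key_value_heads, "
            f"got {num_query_heads} and {num_key_value_heads}"
        )
    if not 0 <= query_head < num_query_heads:
        raise IndexError(f"query_head {query_head} is out of range")
    group_size = num_query_heads // num_key_value_heads
    return query_head // group_size


def head_mapping(num_query_heads, num_key_value_heads):
    """List the key-value head index used by every query head."""
    return [
        key_value_head_index(q, num_query_heads, num_key_value_heads)
        for q in range(num_query_heads)
    ]


if __name__ == "__main__":
    print("GQA:", head_mapping(8, 2))  # two groups of four query heads
    print("MQA:", head_mapping(8, 1))  # every query head shares one key-value head
    print("MHA:", head_mapping(8, 8))  # one key-value head per query head
```

Under this assumption, setting `num_key_value_heads` to 1 gives multi-query attention and setting it equal to `num_query_heads` gives ordinary multi-head attention, so a single layer can cover all three cases.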
1 change: 1 addition & 0 deletions keras/layers/__init__.py
@@ -7,6 +7,7 @@
from keras.layers.activations.softmax import Softmax
from keras.layers.attention.additive_attention import AdditiveAttention
from keras.layers.attention.attention import Attention
from keras.layers.attention.grouped_query_attention import GroupedQueryAttention
from keras.layers.attention.multi_head_attention import MultiHeadAttention
from keras.layers.convolutional.conv1d import Conv1D
from keras.layers.convolutional.conv1d_transpose import Conv1DTranspose