
Commit 5c80ed2

Enable Flash attention in Gemma (keras-team#2064)
* add flash attention to gemma
* update attention mask
* code reformat
* use flash attention detection from utils
* add gemma flash attention
* enable only in jax backend
* update jax version
* Update requirements-jax-cuda.txt
* update jax version in requirements
* update to python 3.10
* add quotes on python version
* force jax to be 0.5.0
* check if dot product attention is supported
* update python version
* update python version
* unpin stable
* change back
1 parent ebc56b4 commit 5c80ed2
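
As a quick orientation, the sketch below shows how this change is meant to be exercised from user code. It is not part of the commit: the `KERAS_BACKEND` setup, the `gemma_2b_en` preset name, and the `GemmaCausalLM.from_preset` call are assumptions about a typical KerasHub setup, and the new fast path only engages when `has_flash_attention_support()` returns `True` and attention dropout is 0.

```python
# Hedged usage sketch (not from the commit): run Gemma on the JAX backend so
# the new flash-attention branch in CachedGemmaAttention can activate.
import os

os.environ["KERAS_BACKEND"] = "jax"  # the fast path is only wired up for JAX

import keras
import keras_hub

# The new has_flash_attention_support() helper keys off this config API.
print("flash attention enabled:", keras.config.is_flash_attention_enabled())

# Preset name is illustrative; any Gemma preset with zero attention dropout works.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
print(gemma_lm.generate("Flash attention speeds up", max_length=32))
```

On non-JAX backends, or on JAX versions without `jax.nn.dot_product_attention`, generation still works; the model falls back to the einsum attention path shown in the diff below.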

3 files changed: +48 -5 lines changed

CONTRIBUTING.md (+3, -3)
```diff
@@ -58,7 +58,7 @@ development environment and run the unit tests. This is covered in section
 ### Step 3. Create a pull request
 
 Once the change is ready, open a pull request from your branch in your fork to
-the master branch in
+the master branch in
 [keras-team/keras-hub](https://github.com/keras-team/keras-hub).
 
 ### Step 4. Sign the Contributor License Agreement
@@ -114,13 +114,13 @@ environement supports all backends without cuda, and each backend environement
 has cuda support.
 
 ```shell
-conda create -y -n keras-hub-cpu python=3.10
+conda create -y -n keras-hub-cpu python=3.9
 conda activate keras-hub-cpu
 pip install -r requirements.txt # install deps
 pip install -e . # install keras-hub
 
 for backend in "jax" "torch" "tensorflow"; do
-conda create -y -n keras-hub-${backend} python=3.10
+conda create -y -n keras-hub-${backend} python=3.9
 conda activate keras-hub-${backend}
 pip install -r requirements-${backend}-cuda.txt # install deps
 pip install -e . # install keras-hub
```

keras_hub/src/models/gemma/gemma_attention.py (+32, -1)
```diff
@@ -4,6 +4,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 
 
 class CachedGemmaAttention(keras.layers.Layer):
@@ -117,6 +118,36 @@ def _compute_attention(
             query_normalization = 1 / np.sqrt(
                 self.hidden_dim // self.num_query_heads
             )
+        use_dot_product_attention = not (
+            self.dropout > 0.0 or (len(q.shape) != 4)
+        )
+        if has_flash_attention_support() and use_dot_product_attention:
+            if self.dropout > 0.0:
+                raise ValueError(
+                    "Flash attention does not support dropout. "
+                    "Please set `dropout` to 0.0."
+                )
+            if attention_mask is not None:
+                while len(attention_mask.shape) < 4:
+                    attention_mask = ops.expand_dims(
+                        attention_mask, axis=1
+                    )  # Add dimension for num_heads
+                if attention_mask.shape[1] != self.num_query_heads:
+                    attention_mask = ops.tile(
+                        attention_mask, [1, self.num_query_heads, 1, 1]
+                    )
+
+            attention_output = ops.dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                bias=None,
+                mask=attention_mask,
+                scale=query_normalization,
+                is_causal=True,
+                flash_attention=True,
+            )
+            return attention_output
 
         q *= ops.cast(query_normalization, dtype=q.dtype)
         q_shape = ops.shape(q)
@@ -131,8 +162,8 @@ def _compute_attention(
         )
         b, q_len, _, _, h = ops.shape(q)
 
+        # Fallback to standard attention if flash attention is disabled
         attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
-
         if self.logit_soft_cap is not None:
             attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
             attention_logits = ops.multiply(
```
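
To make the new branch easier to follow in isolation, here is a minimal, self-contained sketch of the same mask broadcasting plus `keras.ops.dot_product_attention` call. It is not the commit's code: the tensor shapes, head count, and the all-ones mask are made up for illustration, it assumes a Keras version new enough to expose the `flash_attention` argument, and `flash_attention` is left as `None` (auto) so the snippet also runs on hardware without a flash kernel, whereas the commit passes `flash_attention=True` after its support check.

```python
import numpy as np
from keras import ops

# Illustrative shapes: (batch, seq_len, num_heads, head_dim), the layout
# keras.ops.dot_product_attention expects.
batch, seq_len, num_heads, head_dim = 2, 16, 8, 32
rng = np.random.default_rng(0)


def rand(shape):
    return ops.convert_to_tensor(rng.standard_normal(shape).astype("float32"))


q = rand((batch, seq_len, num_heads, head_dim))
k = rand((batch, seq_len, num_heads, head_dim))
v = rand((batch, seq_len, num_heads, head_dim))

# Start from a 3D (batch, q_len, kv_len) mask, as Gemma typically passes in,
# then broadcast it to (batch, num_heads, q_len, kv_len) the same way the
# new branch does.
attention_mask = ops.ones((batch, seq_len, seq_len), dtype="bool")
while len(attention_mask.shape) < 4:
    attention_mask = ops.expand_dims(attention_mask, axis=1)  # add num_heads dim
if attention_mask.shape[1] != num_heads:
    attention_mask = ops.tile(attention_mask, [1, num_heads, 1, 1])

attention_output = ops.dot_product_attention(
    query=q,
    key=k,
    value=v,
    bias=None,
    mask=attention_mask,
    scale=1.0 / np.sqrt(head_dim),  # mirrors query_normalization
    is_causal=True,
    flash_attention=None,  # auto; the commit passes True after its check
)
print(attention_output.shape)  # (2, 16, 8, 32)
```

Leaving `flash_attention` unset lets Keras pick the best available kernel, which keeps the sketch runnable on CPU; the commit can pass `True` because `has_flash_attention_support()` has already vetted the backend.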

keras_hub/src/utils/keras_utils.py (+13, -1)
```diff
@@ -56,7 +56,19 @@ def standardize_data_format(data_format):
 
 
 def has_flash_attention_support():
-    if hasattr(keras.config, "is_flash_attention_enabled"):
+    if (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "jax"
+    ):
+        try:
+            from jax.nn import dot_product_attention as dot_product_attention
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current JAX version. "
+                "Please update it by following the official guide: "
+                "https://jax.readthedocs.io/en/latest/installation.html"
+            )
+            return False
         return True
     else:
         return False
```
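
For reference, the same capability probe can be written as a standalone script. This is a hedged sketch, not the project's helper: the function name `jax_flash_attention_available` and the stdlib `logging` usage are illustrative (the real `keras_utils.py` has its own `logging` import), and the only assumptions are the checks the diff itself makes.

```python
import logging

import keras


def jax_flash_attention_available():
    """Illustrative stand-in for has_flash_attention_support()."""
    if not hasattr(keras.config, "is_flash_attention_enabled"):
        # Older Keras 3 releases predate the flash-attention config API.
        return False
    if keras.config.backend() != "jax":
        # The commit only enables the fast path on the JAX backend.
        return False
    try:
        # Present only in newer JAX releases; the import doubles as a version check.
        from jax.nn import dot_product_attention  # noqa: F401
    except ImportError:
        logging.warning(
            "jax.nn.dot_product_attention is unavailable; upgrade JAX "
            "(see https://jax.readthedocs.io/en/latest/installation.html)."
        )
        return False
    return True


print("flash attention usable:", jax_flash_attention_available())
```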
