
Commit 3ab5189

Merge branch 'add_small_100_preset' of https://github.com/pkgoogle/keras-hub into add_small_100_preset

2 parents: 948ea8f + 921e113

10 files changed: +174 −11 lines

CONTRIBUTING.md

+3 −3

@@ -58,7 +58,7 @@ development environment and run the unit tests. This is covered in section
 ### Step 3. Create a pull request
 
 Once the change is ready, open a pull request from your branch in your fork to
-the master branch in
+the master branch in
 [keras-team/keras-hub](https://github.com/keras-team/keras-hub).
 
 ### Step 4. Sign the Contributor License Agreement

@@ -114,13 +114,13 @@ environement supports all backends without cuda, and each backend environement
 has cuda support.
 
 ```shell
-conda create -y -n keras-hub-cpu python=3.10
+conda create -y -n keras-hub-cpu python=3.9
 conda activate keras-hub-cpu
 pip install -r requirements.txt # install deps
 pip install -e . # install keras-hub
 
 for backend in "jax" "torch" "tensorflow"; do
-conda create -y -n keras-hub-${backend} python=3.10
+conda create -y -n keras-hub-${backend} python=3.9
 conda activate keras-hub-${backend}
 pip install -r requirements-${backend}-cuda.txt # install deps
 pip install -e . # install keras-hub
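After activating one of the per-backend environments created above, a quick sanity check is to pin the backend and confirm the editable install is importable. This is a generic sketch for illustration, not part of the contributing guide itself.

```python
import os

# KERAS_BACKEND must be set before Keras is imported; "jax", "torch" and
# "tensorflow" match the environments created by the loop above.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import keras_hub  # noqa: F401  # confirms `pip install -e .` worked

print(keras.config.backend())  # -> "jax"
```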

README.md

+7

@@ -102,6 +102,13 @@ To install the latest KerasHub release with Keras 3, simply run:
 pip install --upgrade keras-hub
 ```
 
+Our text tokenizers are based on TensorFlow Text. Hence, if you are using any
+model which has language as a modality, you will have to run:
+
+```
+pip install --upgrade keras-hub[nlp]
+```
+
 To install the latest nightly changes for both KerasHub and Keras, you can use
 our nightly package.
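As a rough illustration of why the extra is needed: KerasHub's text tokenizers call into TensorFlow Text when a text preset is loaded, so the load only succeeds once `keras-hub[nlp]` (or `tensorflow-text` directly) is installed. The preset name below is just an example.

```python
import keras_hub

# Loading a text tokenizer requires TensorFlow Text under the hood; without
# the `nlp` extra this raises an ImportError. "bert_base_en" is one example
# preset from the KerasHub collection.
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset("bert_base_en")
print(tokenizer("The quick brown fox."))
```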

keras_hub/src/models/backbone.py

+10 −1

@@ -186,14 +186,23 @@ def save_to_preset(self, preset_dir):
         saver = get_preset_saver(preset_dir)
         saver.save_backbone(self)
 
+    def get_lora_target_names(self):
+        """Returns list of layer names which are to be LoRA-fied.
+
+        Subclasses can override this method if the names of layers to be
+        LoRA-fied are different.
+        """
+        return ["query_dense", "value_dense", "query", "value"]
+
     def enable_lora(self, rank):
         """Enable Lora on the backbone.
 
         Calling this method will freeze all weights on the backbone,
         while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
         """
-        target_names = ["query_dense", "value_dense", "query", "value"]
+        target_names = self.get_lora_target_names()
+
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
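A minimal sketch of what this new hook enables: a model whose attention projections use different layer names can override `get_lora_target_names()` instead of re-implementing `enable_lora()`. The class and layer name below are hypothetical, for illustration only.

```python
from keras_hub.src.models.backbone import Backbone


class MyBackbone(Backbone):
    """Hypothetical backbone whose attention layers use a custom name."""

    def get_lora_target_names(self):
        # Keep the default query/value targets and add our own layer name.
        return super().get_lora_target_names() + ["my_attention_dense"]


# enable_lora() now resolves its targets through the hook, so calling
# backbone.enable_lora(rank=4) would also LoRA-fy "my_attention_dense".
```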

keras_hub/src/models/gemma/gemma_attention.py

+32 −1

@@ -4,6 +4,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
 
 
 class CachedGemmaAttention(keras.layers.Layer):

@@ -117,6 +118,36 @@ def _compute_attention(
         query_normalization = 1 / np.sqrt(
             self.hidden_dim // self.num_query_heads
         )
+        use_dot_product_attention = not (
+            self.dropout > 0.0 or (len(q.shape) != 4)
+        )
+        if has_flash_attention_support() and use_dot_product_attention:
+            if self.dropout > 0.0:
+                raise ValueError(
+                    "Flash attention does not support dropout. "
+                    "Please set `dropout` to 0.0."
+                )
+            if attention_mask is not None:
+                while len(attention_mask.shape) < 4:
+                    attention_mask = ops.expand_dims(
+                        attention_mask, axis=1
+                    )  # Add dimension for num_heads
+                if attention_mask.shape[1] != self.num_query_heads:
+                    attention_mask = ops.tile(
+                        attention_mask, [1, self.num_query_heads, 1, 1]
+                    )
+
+            attention_output = ops.dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                bias=None,
+                mask=attention_mask,
+                scale=query_normalization,
+                is_causal=True,
+                flash_attention=True,
+            )
+            return attention_output
 
         q *= ops.cast(query_normalization, dtype=q.dtype)
         q_shape = ops.shape(q)

@@ -131,8 +162,8 @@ def _compute_attention(
         )
         b, q_len, _, _, h = ops.shape(q)
 
+        # Fallback to standard attention if flash attention is disabled
         attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
-
         if self.logit_soft_cap is not None:
             attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
         attention_logits = ops.multiply(
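To make the fused path concrete, here is a small standalone sketch of `keras.ops.dot_product_attention` using the same `(batch, seq_len, num_heads, head_dim)` layout that the `len(q.shape) == 4` check expects. It assumes a Keras version recent enough to expose this op (the diff above relies on the same); the shapes and scale are illustrative only.

```python
import numpy as np
from keras import ops

batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
shape = (batch, seq_len, num_heads, head_dim)
q, k, v = [
    ops.convert_to_tensor(np.random.rand(*shape), dtype="float32")
    for _ in range(3)
]

# Causal attention scaled by 1 / sqrt(head_dim), mirroring the layer's
# `query_normalization`; leaving `flash_attention` unset lets Keras decide
# whether a fused kernel is available on the current backend.
out = ops.dot_product_attention(
    q, k, v, scale=1.0 / np.sqrt(head_dim), is_causal=True
)
print(out.shape)  # (2, 8, 4, 16)
```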

keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

+7

@@ -274,6 +274,13 @@ def __init__(
         # Keep the image_sequence_length as a backbone property for easy access.
         self.image_sequence_length = self.vit_encoder.image_sequence_length
 
+    def get_lora_target_names(self):
+        target_names = super().get_lora_target_names()
+
+        # Add these for `PaliGemmaVITAttention`.
+        target_names += ["query_proj", "value_proj"]
+        return target_names
+
     def get_config(self):
         config = super().get_config()
         config.update(
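With this override in place, enabling LoRA on a PaliGemma backbone targets both the Gemma attention layers and the SigLIP ViT projections. A rough usage sketch (the preset name is one registered in this repository; downloading the weights requires Kaggle access):

```python
import keras_hub

backbone = keras_hub.models.PaliGemmaBackbone.from_preset("pali_gemma2_mix_3b_224")
print(backbone.get_lora_target_names())
# ["query_dense", "value_dense", "query", "value", "query_proj", "value_proj"]

# Freezes the base weights and adds rank-4 LoRA adapters to every matching
# EinsumDense layer, now including the ViT attention projections.
backbone.enable_lora(rank=4)
```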

keras_hub/src/models/pali_gemma/pali_gemma_presets.py

+93 −3

@@ -83,6 +83,96 @@
         },
         "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/2",
     },
+    "pali_gemma2_mix_3b_224": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, image size 224, 27-layer for "
+                "SigLIP-So400m vision encoder and 26-layer Gemma2 2B language "
+                "model. This model has been fine-tuned on a wide range of "
+                "vision-language tasks and domains."
+            ),
+            "params": 3032094960,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_224/2",
+    },
+    "pali_gemma2_mix_3b_448": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 26-layer Gemma2 2B language "
+                "model. This model has been fine-tuned on a wide range of "
+                "vision-language tasks and domains."
+            ),
+            "params": 3032979696,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_3b_448/2",
+    },
+    "pali_gemma2_mix_10b_224": {
+        "metadata": {
+            "description": (
+                "10 billion parameter, image size 224, 27-layer for "
+                "SigLIP-So400m vision encoder and 42-layer Gemma2 9B language "
+                "model. This model has been fine-tuned on a wide range of "
+                "vision-language tasks and domains."
+            ),
+            "params": 9662409456,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_224/2",
+    },
+    "pali_gemma2_mix_10b_448": {
+        "metadata": {
+            "description": (
+                "10 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 42-layer Gemma2 9B language "
+                "model. This model has been fine-tuned on a wide range of "
+                "vision-language tasks and domains."
+            ),
+            "params": 9663294192,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_448/2",
+    },
+    "pali_gemma2_mix_28b_224": {
+        "metadata": {
+            "description": (
+                "28 billion parameter, image size 224, 27-layer for "
+                "SigLIP-So400m vision encoder and 46-layer Gemma2 27B language "
+                "model. This model has been fine-tuned on a wide range of "
+                "vision-language tasks and domains."
+            ),
+            "params": 27650192112,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_224/2",
+    },
+    "pali_gemma2_mix_28b_448": {
+        "metadata": {
+            "description": (
+                "28 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 46-layer Gemma2 27B language "
+                "model. This model has been fine-tuned on a wide range of "
+                "vision-language tasks and domains."
+            ),
+            "params": 27650192112,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_448/2",
+    },
     "pali_gemma2_pt_3b_224": {
         "metadata": {
             "description": (

@@ -181,7 +271,7 @@
                 "model. This model has been pre-trained on a mixture of "
                 "datasets."
             ),
-            "params": 9662409456,
+            "params": 27650192112,
             "official_name": "PaliGemma2",
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",

@@ -196,7 +286,7 @@
                 "model. This model has been pre-trained on a mixture of "
                 "datasets."
             ),
-            "params": 9663294192,
+            "params": 27650192112,
             "official_name": "PaliGemma2",
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",

@@ -211,7 +301,7 @@
                 "model. This model has been pre-trained on a mixture of "
                 "datasets."
             ),
-            "params": 9666833136,
+            "params": 27650192112,
             "official_name": "PaliGemma2",
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",

keras_hub/src/utils/keras_utils.py

+13 −1

@@ -56,7 +56,19 @@ def standardize_data_format(data_format):
 
 
 def has_flash_attention_support():
-    if hasattr(keras.config, "is_flash_attention_enabled"):
+    if (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "jax"
+    ):
+        try:
+            from jax.nn import dot_product_attention as dot_product_attention
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current JAX version. "
+                "Please update it by following the official guide: "
+                "https://jax.readthedocs.io/en/latest/installation.html"
+            )
+            return False
         return True
     else:
         return False
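The probe added here boils down to an import check. A standalone sketch of the same pattern, safe to run even without JAX installed, since a missing package also raises `ImportError`:

```python
def jax_has_fused_attention():
    # `jax.nn.dot_product_attention` only exists in sufficiently new JAX
    # releases; treat its absence (or a missing JAX install) as "no flash
    # attention support".
    try:
        from jax.nn import dot_product_attention  # noqa: F401
    except ImportError:
        return False
    return True


print(jax_has_fused_attention())
```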

requirements.txt

+2 −1

@@ -1,5 +1,6 @@
 # Tensorflow.
-tensorflow-cpu~=2.18
+tensorflow-cpu~=2.18.0;sys_platform != 'darwin'
+tensorflow~=2.18.0;sys_platform == 'darwin'
 tensorflow-text~=2.18
 
 # Torch.
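The two new requirement lines rely on PEP 508 environment markers, which pip evaluates against the running interpreter. The selection is roughly equivalent to this Python sketch:

```python
import sys

# pip's `sys_platform` marker maps to Python's sys.platform: "darwin" on
# macOS, so macOS gets the plain `tensorflow` wheel and every other platform
# gets the CPU-only build.
if sys.platform == "darwin":
    requirement = "tensorflow~=2.18.0"
else:
    requirement = "tensorflow-cpu~=2.18.0"
print(requirement)
```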

setup.py

+1 −1

@@ -45,13 +45,13 @@ def get_version(rel_path):
         "regex",
         "rich",
         "kagglehub",
-        "tensorflow-text",
     ],
     extras_require={
         "extras": [
             "rouge-score",
             "sentencepiece",
         ],
+        "nlp": ["tensorflow-text"],
     },
     # Supported Python versions
     python_requires=">=3.9",

tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py

+6

@@ -69,6 +69,12 @@
     "pali_gemma2_10b_ft_docci_448": (
         "google/paligemma-2/jax/paligemma2-10b-ft-docci-448"
     ),
+    "pali_gemma2_3b_mix_224": "google/paligemma-2/jax/paligemma2-3b-mix-224",
+    "pali_gemma2_3b_mix_448": "google/paligemma-2/jax/paligemma2-3b-mix-448",
+    "pali_gemma2_10b_mix_224": "google/paligemma-2/jax/paligemma2-10b-mix-224",
+    "pali_gemma2_10b_mix_448": "google/paligemma-2/jax/paligemma2-10b-mix-448",
+    "pali_gemma2_28b_mix_224": "google/paligemma-2/jax/paligemma2-28b-mix-224",
+    "pali_gemma2_28b_mix_448": "google/paligemma-2/jax/paligemma2-28b-mix-448",
     "pali_gemma2_3b_pt_224": "google/paligemma-2/jax/paligemma2-3b-pt-224",
     "pali_gemma2_3b_pt_448": "google/paligemma-2/jax/paligemma2-3b-pt-448",
     "pali_gemma2_3b_pt_896": "google/paligemma-2/jax/paligemma2-3b-pt-896",
