Commit ebc56b4

Add query_proj, value_proj to target names for enable_lora (keras-team#2107)
1 parent aac7257 commit ebc56b4

File tree

2 files changed: +17 -1 lines changed

keras_hub/src/models/backbone.py

+10 -1

@@ -186,14 +186,23 @@ def save_to_preset(self, preset_dir):
         saver = get_preset_saver(preset_dir)
         saver.save_backbone(self)
 
+    def get_lora_target_names(self):
+        """Returns list of layer names which are to be LoRA-fied.
+
+        Subclasses can override this method if the names of layers to be
+        LoRA-fied are different.
+        """
+        return ["query_dense", "value_dense", "query", "value"]
+
     def enable_lora(self, rank):
         """Enable Lora on the backbone.
 
         Calling this method will freeze all weights on the backbone,
         while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
         """
-        target_names = ["query_dense", "value_dense", "query", "value"]
+        target_names = self.get_lora_target_names()
+
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
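This change turns the previously hard-coded target list into an overridable hook, so model-specific backbones can extend it without touching `enable_lora` itself. A minimal sketch of how a subclass could use the hook, assuming a hypothetical `MyBackbone` with an extra `key_proj` attention projection (neither name is part of this commit):

from keras_hub.models import Backbone

class MyBackbone(Backbone):
    def get_lora_target_names(self):
        # Base list from the parent class:
        # ["query_dense", "value_dense", "query", "value"].
        target_names = super().get_lora_target_names()
        # Hypothetical extra projection layer used by this architecture.
        target_names += ["key_proj"]
        return target_names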
keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

+7

@@ -274,6 +274,13 @@ def __init__(
         # Keep the image_sequence_length as a backbone property for easy access.
         self.image_sequence_length = self.vit_encoder.image_sequence_length
 
+    def get_lora_target_names(self):
+        target_names = super().get_lora_target_names()
+
+        # Add these for `PaliGemmaVITAttention`.
+        target_names += ["query_proj", "value_proj"]
+        return target_names
+
     def get_config(self):
         config = super().get_config()
         config.update(
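With this override in place, `enable_lora` picks up the ViT attention projections automatically. A usage sketch, assuming the public `keras_hub.models.PaliGemmaBackbone` export and a `pali_gemma_3b_224` preset (the preset name is an assumption for illustration; any PaliGemma preset would behave the same way):

from keras_hub.models import PaliGemmaBackbone

# Preset name assumed for illustration.
backbone = PaliGemmaBackbone.from_preset("pali_gemma_3b_224")

# Freezes all backbone weights, then enables rank-4 LoRA on layers whose
# names match the target list, which now includes `query_proj` and
# `value_proj` inside `PaliGemmaVITAttention`.
backbone.enable_lora(rank=4)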