Skip to content

Commit b0fe238

Browse files
Fix save pretrained for granite speech
1 parent 00503bb commit b0fe238

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

src/transformers/models/granite_speech/configuration_granite_speech.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838

3939
## adapted from transformers.models.blip.configuration_blip_2.Blip2VisionConfig
4040
class GraniteSpeechProjectorConfig(PretrainedConfig):
41-
model_type = "blip_2_qformer"
41+
model_type = "granite_speech_qformer"
4242

4343
def __init__(
4444
self,
@@ -107,9 +107,7 @@ def __init__(
107107
text_config = CONFIG_MAPPING["granite"]()
108108

109109
if isinstance(projector_config, dict):
110-
# TODO - Make this generic after blip2qformer is moved out to its own model dir.
111-
if projector_config["model_type"] != "blip_2_qformer":
112-
raise ValueError("Granite speech currently requires blip2 qformer as its encoder!")
110+
# TODO - In the future, we should make this generic.
113111
projector_config = GraniteSpeechProjectorConfig(**projector_config)
114112
elif projector_config is None:
115113
projector_config = GraniteSpeechProjectorConfig()

src/transformers/models/granite_speech/modeling_granite_speech.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,18 @@ def generate(self, *args, **kwargs):
13771377
self.disable_adapters()
13781378
return super().generate(*args, input_features=input_features, **kwargs)
13791379

1380+
def save_pretrained(self, *args, **kwargs):
1381+
# overwrite save_pretrained to first save the adapter if we have one
1382+
# NOTE - this will use the base model path we are exporting in the lora
1383+
# adapter, which may not necessarily be the best behavior, but for now
1384+
# we keep this for portability, since using the local dir causes problems
1385+
# if the model is loaded from outside of the current working dir.
1386+
if is_peft_available and self._hf_peft_config_loaded:
1387+
super().save_pretrained(*args, **kwargs)
1388+
# Then save the base model afterwards
1389+
self._hf_peft_config_loaded = False
1390+
super().save_pretrained(*args, **kwargs)
1391+
13801392

13811393
__all__ = [
13821394
"GraniteSpeechForConditionalGeneration",

tests/models/granite_speech/test_modeling_granite_speech.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
"layer_norm_eps": 1e-12,
9797
"llm_dim": 32,
9898
"max_position_embeddings": 2048,
99-
"model_type": "blip_2_qformer",
99+
"model_type": "granite_speech_qformer",
100100
"num_attention_heads": 4,
101101
"num_hidden_layers": 2,
102102
"position_embedding_type": "absolute",

0 commit comments

Comments (0)