Commit 22b0a89

Granite speech speedup + model saving bugfix (#39028)
* ensure the query is updated during training, to avoid unused parameters that DDP does not like
* avoid a crash when `kwargs` contain `padding=True`; trainers often pass this argument automatically
* remove the mel_spec lazy init and rename it to mel_filters, so save_pretrained will not crash when saving the processor during training (https://github.com/huggingface/transformers/blob/d5d007a1a0f0c11a726a54c8f00bd71825f84d02/src/transformers/feature_extraction_utils.py#L595); most feature extractors have a `sampling_rate` property
* speed up the relative position embeddings
* fix several issues in model saving/loading (see the usage sketch below):
  - avoid modifying `self._hf_peft_config_loaded` when saving
  - adapter_config automatically points to the original base model; a finetuned version should point to the model save dir
  - fix model weight names that are changed by adding an adapter
* fix a crash when peft is not active
* add a TODO to replace the einsum
1 parent 1d45d90 commit 22b0a89
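
To put the model-saving fixes in context, here is a minimal usage sketch of the intended round trip. It assumes peft and the transformers PEFT integration are installed; the checkpoint id and output directory are illustrative and not taken from this commit.

    from transformers import GraniteSpeechForConditionalGeneration

    # illustrative checkpoint id and output directory
    model = GraniteSpeechForConditionalGeneration.from_pretrained("ibm-granite/granite-speech-3.3-2b")
    save_dir = "./granite-speech-finetuned"

    # with this commit, save_pretrained first writes the LoRA adapter (its
    # adapter_config now pointing at save_dir instead of the original base
    # checkpoint) and then the base model weights, with the adapter-induced
    # key renames undone
    model.save_pretrained(save_dir)

    # the saved directory is intended to be reloadable on its own
    reloaded = GraniteSpeechForConditionalGeneration.from_pretrained(save_dir)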


src/transformers/models/granite_speech/modeling_granite_speech.py

Lines changed: 30 additions & 9 deletions
@@ -159,8 +159,12 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) ->
         # shaw's relative positional embedding
         dist = attention_dists.to(hidden_states.device)
         rel_pos_emb = self.rel_pos_emb(dist)
-        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
-        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # alternative computation of `pos_attn` - for readability
+        # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+        # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # einsum implementation of pos_attn - gives x30 speedup over the alternative
+        # TODO (@avihu111) find a fast alternative to einsum
+        pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale
 
         if remainder > 0:
             # masked attention in the extended block
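
For reference, a quick self-contained check (arbitrary sizes; the letters follow the einsum pattern in the hunk above) that the einsum form matches the previous broadcast-and-sum form:

    import torch

    # arbitrary small sizes matching "b m h c d, c r d -> b m h c r"
    b, m, h, c, r, d = 2, 3, 4, 5, 9, 16
    query_states = torch.randn(b, m, h, c, d)
    rel_pos_emb = torch.randn(c, r, d)
    scale = d ** -0.5  # stand-in for self.scale

    # previous broadcast-and-sum formulation
    rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
    pos_attn_ref = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * scale

    # einsum formulation introduced by this commit
    pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * scale

    assert pos_attn.shape == (b, m, h, c, r)
    assert torch.allclose(pos_attn, pos_attn_ref, atol=1e-5)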
@@ -541,17 +545,34 @@ def generate(self, *args, **kwargs) -> torch.LongTensor:
                 self.disable_adapters()
         return super().generate(*args, input_features=input_features, **kwargs)
 
-    def save_pretrained(self, *args, **kwargs):
+    def save_pretrained(self, save_directory, *args, **kwargs):
         # overwrite save_pretrained to first save the adapter if we have one
-        # NOTE - this will use the base model path we are exporting in the lora
-        # adapter, which may not necessarily be the best behavior, but for now
-        # we keep this for portability, since using the local dir causes problems
-        # if the model is loaded from outside of the current working dir.
         if is_peft_available and self._hf_peft_config_loaded:
-            super().save_pretrained(*args, **kwargs)
+            adapter_name = self._get_adapter_name()
+            self.peft_config[adapter_name].base_model_name_or_path = save_directory
+            super().save_pretrained(save_directory, *args, **kwargs)
         # Then save the base model afterwards
+        prev_val = self._hf_peft_config_loaded
         self._hf_peft_config_loaded = False
-        super().save_pretrained(*args, **kwargs)
+        super().save_pretrained(save_directory, *args, **kwargs)
+        self._hf_peft_config_loaded = prev_val
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        # save the model with the original weights format
+        return key.replace(".base_layer", ""), False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        if is_peft_available and self._hf_peft_config_loaded:
+            # state dict is only adapter, should keep the same
+            return state_dict
+        # rename back the base model state dict
+        return {
+            self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+        }
+
+    def _get_adapter_name(self):
+        return list(self.peft_config.keys())[0]
 
 
 __all__ = [
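
The new `_fix_state_dict_keys_on_save` path strips adapter artifacts before the base weights are written. A toy sketch of that key cleanup, with hypothetical key names that only mirror typical PEFT LoRA naming:

    # hypothetical state-dict keys, not taken from the real model
    state_dict = {
        "encoder.layers.0.linear.base_layer.weight": "W",      # weight wrapped by the adapter
        "encoder.layers.0.linear.lora_A.default.weight": "A",  # adapter-only weight
        "encoder.layers.0.linear.lora_B.default.weight": "B",  # adapter-only weight
        "encoder.layers.1.norm.weight": "G",                   # untouched weight
    }

    # same cleanup as _fix_state_dict_keys_on_save when the base model is being
    # saved: drop ".lora_" entries and strip the ".base_layer" wrapper from keys
    cleaned = {
        key.replace(".base_layer", ""): value
        for key, value in state_dict.items()
        if ".lora_" not in key
    }

    assert cleaned == {
        "encoder.layers.0.linear.weight": "W",
        "encoder.layers.1.norm.weight": "G",
    }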
