Merged
39 changes: 30 additions & 9 deletions src/transformers/models/granite_speech/modeling_granite_speech.py
@@ -159,8 +159,12 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) ->
# shaw's relative positional embedding
dist = attention_dists.to(hidden_states.device)
rel_pos_emb = self.rel_pos_emb(dist)
- rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
- pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+ # alternative computation of `pos_attn` - for readability
+ # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+ # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+ # einsum implementation of pos_attn - gives x30 speedup over the alternative
+ # TODO (@avihu111) find a fast alternative to einsum
+ pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale

if remainder > 0:
# masked attention in the extended block
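As a sanity check on the change above, here is a minimal standalone sketch (not part of the diff) comparing the einsum form against the commented-out broadcast-and-sum form; the tensor sizes (batch, blocks, heads, context, relative positions, head dim) are made up for illustration:

import torch

# hypothetical sizes: batch, blocks, heads, context, relative positions, head dim
b, m, h, c, r, d = 2, 3, 4, 16, 31, 64
query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)
scale = d ** -0.5

# broadcast-and-sum form (the "readable" variant kept in the comments)
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn_ref = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * scale

# einsum form used by the new code
pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * scale

assert torch.allclose(pos_attn, pos_attn_ref, atol=1e-4)

Both produce a tensor of shape (b, m, h, c, r); the einsum avoids materializing the intermediate (b, m, h, c, r, d) product, which is presumably where the speedup comes from.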
@@ -541,17 +545,34 @@ def generate(self, *args, **kwargs) -> torch.LongTensor:
self.disable_adapters()
return super().generate(*args, input_features=input_features, **kwargs)

- def save_pretrained(self, *args, **kwargs):
+ def save_pretrained(self, save_directory, *args, **kwargs):
# overwrite save_pretrained to first save the adapter if we have one
# NOTE - this will use the base model path we are exporting in the lora
# adapter, which may not necessarily be the best behavior, but for now
# we keep this for portability, since using the local dir causes problems
# if the model is loaded from outside of the current working dir.
if is_peft_available and self._hf_peft_config_loaded:
- super().save_pretrained(*args, **kwargs)
+ adapter_name = self._get_adapter_name()
+ self.peft_config[adapter_name].base_model_name_or_path = save_directory
(PR author's note: ensures the adapter config points to the finetuned model)

+ super().save_pretrained(save_directory, *args, **kwargs)
# Then save the base model afterwards
+ prev_val = self._hf_peft_config_loaded
self._hf_peft_config_loaded = False
- super().save_pretrained(*args, **kwargs)
+ super().save_pretrained(save_directory, *args, **kwargs)
+ self._hf_peft_config_loaded = prev_val
(PR author's note: bugfix, ensuring save_pretrained would not change the original value)


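Taken together, the override now saves the LoRA adapter first (with its config re-pointed at save_directory) and then the base model with the PEFT flag temporarily cleared and restored. A hedged usage sketch of the intended flow, assuming a PEFT adapter is attached and using an illustrative checkpoint name:

from transformers import GraniteSpeechForConditionalGeneration

# checkpoint name is illustrative; any granite_speech checkpoint with a bundled LoRA adapter applies
model = GraniteSpeechForConditionalGeneration.from_pretrained("ibm-granite/granite-speech-3.3-2b")

# writes the adapter first (its config now points at ./my-finetuned-model),
# then the base weights under their original (non-PEFT) key names
model.save_pretrained("./my-finetuned-model")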
+ @staticmethod
+ def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+ # save the model with the original weights format
+ return key.replace(".base_layer", ""), False
+
+ def _fix_state_dict_keys_on_save(self, state_dict):
+ if is_peft_available and self._hf_peft_config_loaded:
+ # state dict is only adapter, should keep the same
+ return state_dict
+ # rename back the base model state dict
+ return {
+ self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+ }
+
+ def _get_adapter_name(self):
+ return list(self.peft_config.keys())[0]


__all__ = [
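To show what _fix_state_dict_keys_on_save does to the base-model keys when no adapter config is loaded, a small sketch (the key names below are made-up examples of PEFT-wrapped modules, not taken from the actual model):

state_dict = {
    "language_model.layers.0.self_attn.q_proj.base_layer.weight": "W_q",
    "language_model.layers.0.self_attn.q_proj.lora_A.default.weight": "A",
    "language_model.layers.0.self_attn.q_proj.lora_B.default.weight": "B",
}

# mirrors the method body: drop ".lora_" entries and strip ".base_layer" from the rest
cleaned = {
    key.replace(".base_layer", ""): value
    for key, value in state_dict.items()
    if ".lora_" not in key
}
print(cleaned)  # -> {'language_model.layers.0.self_attn.q_proj.weight': 'W_q'}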