Granite speech speedup + model saving bugfix #39028
Merged: ArthurZucker merged 16 commits into huggingface:main from avihu111:granite_speech_updates on Jun 26, 2025.
Changes from all commits (16 commits, all authored by avihu111):
db4a4af  ensure the query is updated during training
8ee3429  avoid a crash when `kwargs` contain `padding=True`
8dec2ba  minor
4db4c99  Remove mel_spec lazy init, and rename to mel_filters.
98844ec  minor - most feature extractors has a `sampling_rate` property
6e68d8c  Merge branch 'main' into granite_speech_updates
7064db7  speedup relative position embeddings
e94d0a1  Merge branch 'huggingface:main' into granite_speech_updates
9c06f95  fix several issues in model saving/loading:
6c2db62  Merge branch 'granite_speech_updates' of https://github.com/avihu111/…
8b79d9e  minor
313e4a2  minor
29a69be  minor
bc152b9  fixing a crash without peft active
0ebb8f0  add todo to replace einsum
0491123  Merge branch 'main' into granite_speech_updates
Diff:
```diff
@@ -159,8 +159,12 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) ->
         # shaw's relative positional embedding
         dist = attention_dists.to(hidden_states.device)
         rel_pos_emb = self.rel_pos_emb(dist)
-        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
-        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # alternative computation of `pos_attn` - for readability
+        # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+        # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+        # einsum implementation of pos_attn - gives x30 speedup over the alternative
+        # TODO (@avihu111) find a fast alternative to einsum
+        pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale

         if remainder > 0:
             # masked attention in the extended block
```
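The einsum form computes the same tensor as the commented-out broadcast-and-sum form; it contracts over the last dimension directly instead of materialising the full six-dimensional intermediate, which is where the reported x30 speedup comes from. A minimal equivalence check, using made-up placeholder shapes (the real sizes come from the conformer configuration, not from these values):

```python
import torch

# Placeholder sizes: batch, blocks, heads, context, head dim, relative range.
# These are illustrative values, not the model's real configuration.
b, m, h, c, d, r = 2, 3, 4, 5, 8, 9
query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)
scale = d ** -0.5  # stand-in for self.scale

# Readable formulation: broadcast to (b, m, h, c, r, d), then reduce over d.
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn_ref = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * scale

# Einsum formulation from the diff: contracts over d directly, never building
# the six-dimensional intermediate tensor.
pos_attn_fast = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * scale

torch.testing.assert_close(pos_attn_ref, pos_attn_fast)
```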
```diff
@@ -541,17 +545,34 @@ def generate(self, *args, **kwargs) -> torch.LongTensor:
             self.disable_adapters()
         return super().generate(*args, input_features=input_features, **kwargs)

-    def save_pretrained(self, *args, **kwargs):
+    def save_pretrained(self, save_directory, *args, **kwargs):
         # overwrite save_pretrained to first save the adapter if we have one
         # NOTE - this will use the base model path we are exporting in the lora
         # adapter, which may not necessarily be the best behavior, but for now
         # we keep this for portability, since using the local dir causes problems
         # if the model is loaded from outside of the current working dir.
         if is_peft_available and self._hf_peft_config_loaded:
-            super().save_pretrained(*args, **kwargs)
+            adapter_name = self._get_adapter_name()
+            self.peft_config[adapter_name].base_model_name_or_path = save_directory
+            super().save_pretrained(save_directory, *args, **kwargs)
         # Then save the base model afterwards
+        prev_val = self._hf_peft_config_loaded
         self._hf_peft_config_loaded = False
-        super().save_pretrained(*args, **kwargs)
+        super().save_pretrained(save_directory, *args, **kwargs)
+        self._hf_peft_config_loaded = prev_val
```
Inline review comment on the change above: "bugfix, ensuring …"
```diff
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        # save the model with the original weights format
+        return key.replace(".base_layer", ""), False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        if is_peft_available and self._hf_peft_config_loaded:
+            # state dict is only adapter, should keep the same
+            return state_dict
+        # rename back the base model state dict
+        return {
+            self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+        }

     def _get_adapter_name(self):
         return list(self.peft_config.keys())[0]


 __all__ = [
```
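The two `_fix_state_dict_key*` helpers make the exported base checkpoint look like a plain (non-PEFT) checkpoint: PEFT wraps tuned linear layers so their weights pick up a `.base_layer` infix, which is stripped on save, and `.lora_*` tensors are dropped because they already live in the adapter file. A small sketch of that rewriting with illustrative key names (the keys below are assumptions, not taken from the real model):

```python
# Illustrative PEFT-style state dict; key names and values are hypothetical.
state_dict = {
    "encoder.layers.0.attn.q_proj.base_layer.weight": "W_q",    # wrapped base weight
    "encoder.layers.0.attn.q_proj.lora_A.default.weight": "A",  # adapter-only tensor
    "encoder.layers.0.attn.q_proj.lora_B.default.weight": "B",  # adapter-only tensor
    "encoder.layers.0.ffn.weight": "W_ffn",                     # untouched layer
}

def fix_key(key: str) -> str:
    # Mirrors _fix_state_dict_key_on_save: undo PEFT's ".base_layer" wrapping.
    return key.replace(".base_layer", "")

# Mirrors _fix_state_dict_keys_on_save in the non-adapter pass:
# drop LoRA tensors and restore the original key names.
exported = {fix_key(k): v for k, v in state_dict.items() if ".lora_" not in k}

assert exported == {
    "encoder.layers.0.attn.q_proj.weight": "W_q",
    "encoder.layers.0.ffn.weight": "W_ffn",
}
```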
Review comment: "ensures the adapter config points to the finetuned model"
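Taken together with the `save_pretrained` change above, a LoRA-finetuned Granite speech model can be round-tripped through a single `save_pretrained(save_directory)` call: the adapter is written first, with `base_model_name_or_path` rewritten to point at `save_directory`, and the base weights are then written under their original key names (no `.base_layer` infix, no `.lora_*` tensors), so both artifacts reload from the same directory.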