Granite speech speedup + model saving bugfix #39028
Conversation
- avoid unused parameters that DDP does not like
- trainers often pass this argument automatically
- this ensures save_pretrained will not crash when saving the processor during training https://github.com/huggingface/transformers/blob/d5d007a1a0f0c11a726a54c8f00bd71825f84d02/src/transformers/feature_extraction_utils.py#L595
- avoid modifying `self._hf_peft_config_loaded` when saving
- adapter_config automatically points to the original base model; a finetuned version should point to the model save dir
- fixing model weights names that are changed by adding an adapter
- …transformers into granite_speech_updates
# rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
# pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
# einsum gives x30 speedup:
pos_attn = torch.einsum('b m h c d, c r d -> b m h c r', query_states, rel_pos_emb) * self.scale
einsum runs significantly faster (measured with 500 repetitions), and has a smaller memory footprint:
einsum: 25.089 ms
existing (explicit dot): 594.220 ms
I was hoping we could use the einsum implementation to speed up inference/finetuning, and keep the equivalent formulation for readability.
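For reference, here is a minimal standalone check that the einsum form matches the explicit dot product (the shapes are made up for illustration, not the actual Granite speech configuration):

import torch

b, m, h, c, r, d = 2, 3, 4, 16, 31, 32  # batch, blocks, heads, context, relative positions, head dim
query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)
scale = d ** -0.5

# einsum form used in the PR
pos_attn_einsum = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * scale

# explicit broadcast-and-sum form kept in the comment
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn_dot = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * scale

print(torch.allclose(pos_attn_einsum, pos_attn_dot, atol=1e-5))  # True, up to float tolerance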
if is_peft_available and self._hf_peft_config_loaded:
super().save_pretrained(*args, **kwargs)
adapter_name = self._get_adapter_name()
self.peft_config[adapter_name].base_model_name_or_path = save_directory
ensures the adapter config points to the finetuned model
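A quick way to sanity-check this after the change (the save directory name and the `model` variable are hypothetical; `model` stands for a Granite speech model with a loaded adapter):

import json, os

save_directory = "./granite-speech-finetuned"  # hypothetical path
model.save_pretrained(save_directory)

with open(os.path.join(save_directory, "adapter_config.json")) as f:
    adapter_config = json.load(f)

# with the fix, the adapter config points at the finetuned checkpoint
# rather than at the original base model
print(adapter_config["base_model_name_or_path"])  # "./granite-speech-finetuned"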
self._hf_peft_config_loaded = False
super().save_pretrained(*args, **kwargs)
super().save_pretrained(save_directory, *args, **kwargs)
self._hf_peft_config_loaded = prev_val
bugfix: ensures save_pretrained does not change the original value
@staticmethod
def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
    # save the model with the original weights format
    return key.replace(".base_layer", ""), False
The adapter changes the original parameter names by adding `.base_layer` to each one.
This hack enables save_pretrained() and from_pretrained() to work as expected.
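For illustration, this is roughly what the renaming does to a couple of state dict keys (the module paths here are hypothetical):

# example PEFT-wrapped keys; module paths are made up
wrapped_keys = [
    "encoder.layers.0.self_attn.q_proj.base_layer.weight",
    "encoder.layers.0.self_attn.k_proj.base_layer.weight",
]

for key in wrapped_keys:
    # mirrors _fix_state_dict_key_on_save: strip ".base_layer" so the saved
    # checkpoint keeps the original (non-PEFT) parameter names
    print(key, "->", key.replace(".base_layer", ""))

# encoder.layers.0.self_attn.q_proj.base_layer.weight -> encoder.layers.0.self_attn.q_proj.weight
# encoder.layers.0.self_attn.k_proj.base_layer.weight -> encoder.layers.0.self_attn.k_proj.weight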
Thanks! Happy to merge as is, but in general einsum is not magic; there is an equivalent implementation out there that only uses matrix notation!
# faster implementation, equivalent to:
# rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
# pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
# einsum gives x30 speedup:
pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale
If einsum is possible I am fairly sure there is a way to do this with just matrix notation! We always avoid einsum in transformers!
Let's add a TODO here as you probably want to have this merged fast!
I really tried finding an alternative!
matmul is not a great choice, since we need a vectorized dot product, and matmul would do redundant computations.
vecdot was a promising direction, but it was actually slower.
I think the einsum speedup has to do with either broadcasting or the half-precision kernels - not sure.
I'll make sure to update it if I learn something new.
Adding a todo - Thanks Arthur!
Hi @avihu111, just saw this while on watch! Try
(query_states.unsqueeze(-2) @ rel_pos_emb.transpose(-1, -2)).squeeze(-2)
Thanks @Rocketknight1 for a great suggestion.
That's a cool trick to get bmm to perform a vectorized dot product.
For some reason, it still performs on par with the explicit dot product, which is still about 50x slower than einsum 😮
I ran the following code to compare all methods:
        # run inside the attention forward, where query_states, rel_pos_emb,
        # self.scale and the reference pos_attn (the einsum result) are in scope
        import time

        for method in ["einsum", "explicit_dot", "vecdot", "bmm"]:
            with torch.amp.autocast("cuda", torch.bfloat16):
                t1 = time.time()
                for _ in range(500):
                    if method == "einsum":
                        cur_pos_attn = torch.einsum('b m h c d, c r d -> b m h c r', query_states, rel_pos_emb) * self.scale
                    elif method == "explicit_dot":
                        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
                        cur_pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
                    elif method == "vecdot":
                        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
                        cur_pos_attn = torch.linalg.vecdot(query_states.unsqueeze(-2), rel_pos_emb_expanded, dim=-1) * self.scale
                    elif method == "bmm":
                        cur_pos_attn = (query_states.unsqueeze(-2) @ rel_pos_emb.transpose(-1, -2)).squeeze(-2) * self.scale
                print(f"{method} took {(time.time() - t1) * 1000:.3f} ms\t max abs diff is {(cur_pos_attn - pos_attn).abs().max().item():.5f}")
Results:
einsum took 27.996 ms    max abs diff is 0.00000
explicit_dot took 1450.317 ms    max abs diff is 0.01862
vecdot took 919.140 ms   max abs diff is 0.03125
bmm took 1426.787 ms     max abs diff is 0.00195
Thanks for the thorough checks! 🤗 makes a lot of sense when we have this huge perf diff!
What does this PR do?
Speeding up the encoder
Reverting Shaw's positional embedding calculation to einsum results in a significant speedup in both inference and training.
We found it to be about 30x faster than the current explicit dot product using bfloat16.
I kept the explicit dot product in a comment for readability - I hope that is acceptable.
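For readers unfamiliar with the notation, the einsum subscripts map to the following dimensions (the sizes below are illustrative, not the real Granite speech configuration):

import torch

batch, blocks, heads, context, rel_positions, head_dim = 2, 3, 4, 16, 31, 32

query_states = torch.randn(batch, blocks, heads, context, head_dim)  # "b m h c d"
rel_pos_emb = torch.randn(context, rel_positions, head_dim)          # "c r d"

# for each query position c, dot its head_dim vector against each of its
# relative-position embedding vectors, giving one score per relative position
pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb)
print(pos_attn.shape)  # torch.Size([2, 3, 4, 16, 31])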
Fixing issues with loading and saving with an adapter
- avoid modifying `_hf_peft_config_loaded` when saving
- adapter_config automatically points to the original base model; a finetuned version should point to the model save dir
- fixing model weights names that are changed by adding an adapter

Maybe there's a better solution for the problems I was facing - I'll be happy to hear your opinion.
I added comments on each code change, along with the necessary context and justification for the change.
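Putting the review snippets together, here is a self-contained toy sketch of how the pieces compose. FakeBase and GraniteSpeechSketch are made-up stand-ins (not the PR's actual classes); only the method bodies mirror the diff excerpts above, and is_peft_available is replaced by a constant for brevity:

is_peft_available = True  # stand-in for transformers' PEFT availability check

class FakeBase:
    # stand-in for PreTrainedModel.save_pretrained
    def save_pretrained(self, save_directory, *args, **kwargs):
        print(f"base save_pretrained -> {save_directory} (peft flag = {self._hf_peft_config_loaded})")

class GraniteSpeechSketch(FakeBase):
    def __init__(self):
        self._hf_peft_config_loaded = True
        self.peft_config = {"default": type("AdapterCfg", (), {"base_model_name_or_path": "original-base-model"})()}

    def _get_adapter_name(self):
        return "default"

    def save_pretrained(self, save_directory, *args, **kwargs):
        prev_val = self._hf_peft_config_loaded
        if is_peft_available and self._hf_peft_config_loaded:
            # point the saved adapter config at the finetuned checkpoint,
            # not at the original base model
            adapter_name = self._get_adapter_name()
            self.peft_config[adapter_name].base_model_name_or_path = save_directory
        # temporarily clear the PEFT flag for the base-class save,
        # then restore it so saving has no side effects on the live model
        self._hf_peft_config_loaded = False
        super().save_pretrained(save_directory, *args, **kwargs)
        self._hf_peft_config_loaded = prev_val

    @staticmethod
    def _fix_state_dict_key_on_save(key):
        # strip the ".base_layer" infix PEFT adds to wrapped modules
        return key.replace(".base_layer", ""), False

model = GraniteSpeechSketch()
model.save_pretrained("./finetuned-dir")
print(model._hf_peft_config_loaded, model.peft_config["default"].base_model_name_or_path)
# True ./finetuned-dir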
Who can review?
@ArthurZucker @eustlb can you give that a look? 🙏
CC: @avishaiElmakies @alex-jw-brooks