
Conversation

@avihu111 (Contributor) commented Jun 25, 2025

What does this PR do?

Speeding up the encoder

Reverting Shaw's positional embedding calculation to einsum results in a significant speedup in both inference and training.
We found it to be roughly 30x faster than the current explicit dot product when running in bfloat16.
I kept the explicit dot product as a comment for readability; I hope that's acceptable.
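For reference, here is a minimal standalone sketch (dummy sizes, not the model code) showing that the two formulations compute the same tensor; the explicit version materializes a large (b, m, h, c, r, d) broadcast intermediate, which lines up with the memory difference noted in the review thread below:

import torch

# Minimal standalone sketch, not the model code: dummy sizes chosen to match the
# einsum subscripts "b m h c d, c r d -> b m h c r"; real sizes come from the encoder config.
b, m, h, c, r, d = 2, 3, 4, 5, 7, 16
query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)
scale = d ** -0.5

# current formulation: the broadcast multiply materializes a (b, m, h, c, r, d) intermediate
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn_explicit = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * scale

# proposed formulation: einsum contracts over d directly
pos_attn_einsum = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * scale

torch.testing.assert_close(pos_attn_explicit, pos_attn_einsum)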

Fixing issues with loading and saving with an adapter

  • When saving a checkpoint, the adapter config pointed to the original base model instead of the fine-tuned model.
  • It fixes a bug where _hf_peft_config_loaded was modified as a side effect of saving.
  • It reverts a tensor renaming that is triggered by adding an adapter.

Maybe there's a better solution for the problems I was facing - I'll be happy to hear your opinion.
I added comments on each code change, along with the necessary context and justification for the change.
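To make the expected behavior concrete, here is a rough usage sketch (the save directory is a placeholder and model is assumed to be a Granite Speech model with its LoRA adapter attached; this is not code from the PR): after saving, the adapter config on disk should point at the save directory rather than at the original base model.

import json
import os

save_dir = "granite_speech_finetuned"  # placeholder path

# `model` is assumed to already exist with its LoRA adapter loaded; it is not built here.
model.save_pretrained(save_dir)

with open(os.path.join(save_dir, "adapter_config.json")) as f:
    adapter_config = json.load(f)

# With the fix, the adapter references the fine-tuned checkpoint, so loading from
# save_dir resolves the base weights saved alongside the adapter.
print(adapter_config["base_model_name_or_path"])  # expected: "granite_speech_finetuned"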

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @eustlb can you give that a look? 🙏
CC: @avishaiElmakies @alex-jw-brooks

avihu111 and others added 10 commits June 15, 2025 16:50
avoid unused parameters that DDP does not like
trainers often pass this argument automatically
- avoid modifying `self._hf_peft_config_loaded` when saving
- adapter_config automatically points to the original base model - a finetuned version should point to the model save dir.
- fixing model weights names, that are changed by adding an adapter.
# rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
# pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
# einsum gives x30 speedup:
pos_attn = torch.einsum('b m h c d, c r d -> b m h c r', query_states, rel_pos_emb) * self.scale
Contributor Author (avihu111):
einsum runs significantly faster (measured with 500 repetitions), and has a smaller memory footprint:

einsum: 25.089 ms
existing (explicit dot): 594.220 ms

I was hoping we could use the einsum implementation to speed up inference/finetuning, and keep the equivalent formulation for readability.

if is_peft_available and self._hf_peft_config_loaded:
super().save_pretrained(*args, **kwargs)
adapter_name = self._get_adapter_name()
self.peft_config[adapter_name].base_model_name_or_path = save_directory
Contributor Author (avihu111):
ensures the adapter config points to the finetuned model

self._hf_peft_config_loaded = False
super().save_pretrained(save_directory, *args, **kwargs)
self._hf_peft_config_loaded = prev_val
Contributor Author (avihu111):
bugfix, ensuring save_pretrained would not change the original value
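As a toy illustration of the pattern being described (the class below is made up and the try/finally is an addition for robustness, not code from this PR): the flag is cleared only for the duration of the base-model save and then restored.

# Toy stand-in for the real model class; only the flag handling is mimicked here.
class FakeModel:
    def __init__(self):
        self._hf_peft_config_loaded = True

    def save_pretrained(self, save_directory):
        prev_val = self._hf_peft_config_loaded
        self._hf_peft_config_loaded = False  # temporarily disabled while saving the base weights
        try:
            print(f"saving base model to {save_directory} (peft flag={self._hf_peft_config_loaded})")
        finally:
            self._hf_peft_config_loaded = prev_val  # restored, so saving has no side effects


model = FakeModel()
model.save_pretrained("out_dir")
assert model._hf_peft_config_loaded is True  # the flag is unchanged after saving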

@staticmethod
def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
    # save the model with the original weights format
    return key.replace(".base_layer", ""), False
Contributor Author (avihu111):
The adapter changes the original parameter names by adding .base_layer to the layers it wraps.
This hack enables save_pretrained() and from_pretrained() to work as expected with the original names.
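A tiny standalone example of the renaming (the parameter name is invented for illustration):

# With a LoRA adapter attached, PEFT wraps the targeted linear layers, so their
# weights show up under an extra ".base_layer" level in the state dict.
adapter_key = "encoder.layers.0.self_attn.q_proj.base_layer.weight"

# _fix_state_dict_key_on_save strips that suffix, so the saved checkpoint keeps
# the original, adapter-free parameter names.
fixed_key = adapter_key.replace(".base_layer", "")
print(fixed_key)  # encoder.layers.0.self_attn.q_proj.weight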

@avihu111 avihu111 marked this pull request as draft June 25, 2025 08:17
@avihu111 avihu111 changed the title Granite speech updates Granite speech speedup + model saving bugfix Jun 25, 2025
@avihu111 avihu111 marked this pull request as ready for review June 25, 2025 12:32
@ArthurZucker (Collaborator) left a comment:
Thanks! Happy to merge as is, but in general einsum is not magic; there is an equivalent implementation out there that only uses matrix notation!

Comment on lines 162 to 166
# faster implementation, equivalent to:
# rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
# pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
# einsum gives x30 speedup:
pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale
Collaborator (ArthurZucker):
If einsum is possible, I am fairly sure there is a way to do this with just matrix notation! We always avoid einsum in transformers!
Let's add a TODO here, as you probably want to have this merged fast!

Contributor Author (avihu111):
I really tried finding an alternative!
matmul is not a great choice, since we need a vectorized dot product and matmul would do redundant computation.
vecdot was a promising direction, but it was actually slower.
I think the einsum speedup has to do with broadcasting or with dedicated half-precision kernels - I'm not sure.
I'll make sure to update it if I learn something new.
Adding a TODO - thanks Arthur!

Member (@Rocketknight1):
Hi @avihu111, just saw this while on watch! Try

(query_states.unsqueeze(-2) @ rel_pos_emb.transpose(-1, -2)).squeeze(-2)
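For readers following the thread, a quick standalone check (dummy sizes matching the einsum subscripts; not code from this PR) shows how the broadcasting behind this trick works out and that it matches the einsum result:

import torch

# Dummy sizes for "b m h c d, c r d -> b m h c r"; real sizes come from the encoder config.
b, m, h, c, r, d = 2, 3, 4, 5, 7, 16
query_states = torch.randn(b, m, h, c, d)
rel_pos_emb = torch.randn(c, r, d)

via_einsum = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb)

# (b, m, h, c, 1, d) @ (c, d, r) broadcasts over the batch dims to (b, m, h, c, 1, r),
# and squeezing the singleton dim gives (b, m, h, c, r).
via_matmul = (query_states.unsqueeze(-2) @ rel_pos_emb.transpose(-1, -2)).squeeze(-2)

torch.testing.assert_close(via_einsum, via_matmul)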

Contributor Author (avihu111):
Thanks @Rocketknight1 for a great suggestion.
That's a cool trick to get bmm to perform a vectorized dot product.
For some reason it still performs on par with the explicit dot product, which is about 50x slower than einsum 😮

I ran the following code to compare all methods:

        for method in ["einsum", "explicit_dot", "vecdot", "bmm"]:
            with torch.amp.autocast("cuda", torch.bfloat16):
                t1 = time.time()
                for _ in range(500):
                    if method == "einsum":
                        cur_pos_attn = torch.einsum('b m h c d, c r d -> b m h c r', query_states, rel_pos_emb) * self.scale
                    elif method == "explicit_dot":
                        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
                        cur_pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
                    elif method == "vecdot":
                        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
                        cur_pos_attn = torch.linalg.vecdot(query_states.unsqueeze(-2), rel_pos_emb_expanded, dim=-1) * self.scale
                    elif method == "bmm":
                        cur_pos_attn = (query_states.unsqueeze(-2) @ rel_pos_emb.transpose(-1, -2)).squeeze(-2) * self.scale

                print(f"{method} took {(time.time() - t1) * 1000:.3f} ms\t max abs diff is {(cur_pos_attn - pos_attn).abs().max().item():.5f}")

Results:

einsum took 27.996 ms    max abs diff is 0.00000
explicit_dot took 1450.317 ms    max abs diff is 0.01862
vecdot took 919.140 ms   max abs diff is 0.03125
bmm took 1426.787 ms     max abs diff is 0.00195

@ArthurZucker ArthurZucker merged commit 22b0a89 into huggingface:main Jun 26, 2025
12 checks passed
@ArthurZucker (Collaborator) commented:
Thanks for the thorough checks! 🤗 Makes a lot of sense when we have this huge perf diff!
