Add Support for Mistral Model in Llama-Adapter Method #1433
Conversation
Thanks for adding support for Mistral. This looks very promising. I only have a couple of comments, please check them out. Also, would it be possible to add a unit test to tests/test_adaption_prompt.py involving a small Mistral model?
Hello @PrakharSaxena24, thank you for the PR, but the logic seems incorrect, as mentioned in the comments. Please look at the Mistral modeling file in transformers for the correct way to handle GQA (grouped query attention).
 adapter_k = (
-    key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
+    key.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor))
The head dim shouldn't change but the number of heads should be reduced in GQA.
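For illustration, a minimal shape check of this point, using toy numbers rather than the PR's actual code:

```python
import torch

# Toy sizes for illustration only (not the values used in the PR).
adapter_len, num_heads, num_kv_heads, head_dim = 10, 32, 8, 128

# In GQA, k_proj/v_proj output num_kv_heads * head_dim features, so the
# reshape should reduce the head count, not shrink head_dim.
key = torch.randn(1, adapter_len, num_kv_heads * head_dim)
adapter_k = key.view(1, adapter_len, num_kv_heads, head_dim)
print(adapter_k.shape)  # torch.Size([1, 10, 8, 128])
```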
I see! Thanks a lot, this seems correct.
Will edit this.
Also, I think I will need to do the same in utils.py.
@@ -100,6 +104,15 @@ def forward(self, **kwargs):
        query_states = compute_query_states(model=self.model, **kwargs)

        previous_dtype = query_states.dtype

        # Reshape and average the extra tensors
No need to reshape and average the query states. The key shape above is (bsz, adapter_seq_len, num_kv_heads, head_dim), the value shape is (bsz, adapter_seq_len, num_kv_heads, head_dim), and the query shape is (bsz, adapter_seq_len, num_heads, head_dim). You would then need to repeat the num_kv_heads to match num_heads, as done in https://github.com/huggingface/transformers/blob/1c31b7aa3bb4e7ef24c77596d2a76f45a770159f/src/transformers/models/mistral/modeling_mistral.py#L193. After that, the attention computation is the same as in the normal MHA case.
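A sketch of the repeat step being referred to, mirroring the expand-based repeat_kv helper in the transformers Mistral modeling file; the shapes and n_rep below are illustrative, not taken from the PR:

```python
import torch

bsz, num_kv_heads, adapter_seq_len, head_dim = 2, 8, 10, 128
n_rep = 4  # num_heads // num_kv_heads

adapter_k = torch.randn(bsz, num_kv_heads, adapter_seq_len, head_dim)
# (bsz, num_kv_heads, seq, head_dim) -> (bsz, num_kv_heads * n_rep, seq, head_dim)
adapter_k = adapter_k[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, adapter_seq_len, head_dim)
adapter_k = adapter_k.reshape(bsz, num_kv_heads * n_rep, adapter_seq_len, head_dim)
print(adapter_k.shape)  # torch.Size([2, 32, 10, 128])
```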
Thanks a lot, so rather than repeating the adapter output, I should repeat adapter_k and adapter_v:

adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)

since the key/value shape is (bsz, num_kv_heads, adapter_seq_len, head_dim), with dim 1 being num_kv_heads. Does this make sense?
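As a quick sanity check, repeating along dim 1 does give the expected number of heads (illustrative sizes, not the PR's actual values):

```python
import torch

bsz, num_kv_heads, adapter_seq_len, head_dim, factor = 2, 8, 10, 128, 4
adapter_k = torch.randn(bsz, num_kv_heads, adapter_seq_len, head_dim)
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
print(adapter_k.shape)  # torch.Size([2, 32, 10, 128]) -> num_kv_heads * factor heads
```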
@BenjaminBossan @pacman100 thank you for your kind comments.
Thanks a lot. From my side, there are only a few comments left, please take a look.
tests/test_adaption_prompt.py
Outdated
@@ -78,7 +106,19 @@ def test_attributes(self) -> None:
        self.assertTrue(hasattr(model, "from_pretrained"))
        self.assertTrue(hasattr(model, "push_to_hub"))

        #Test Mistral
        if self.mistral_available:
Instead of attaching the Mistral tests to the Llama tests, could you please create a separate test for each? You can decorate them with @unittest.skipIf(not is_mistral_available()) to avoid the if self.mistral_available: line.
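A minimal sketch of the suggested layout; the is_mistral_available helper and test name here are stand-ins, not necessarily the code used in tests/test_adaption_prompt.py:

```python
import importlib.util
import unittest


def is_mistral_available() -> bool:
    # Stand-in availability check: returns True if transformers ships a mistral submodule.
    try:
        return importlib.util.find_spec("transformers.models.mistral") is not None
    except ModuleNotFoundError:
        return False


class TestMistralAdaptionPrompt(unittest.TestCase):
    @unittest.skipIf(not is_mistral_available(), "Mistral not available in this transformers version")
    def test_attributes(self) -> None:
        # Mistral-specific assertions would go here, mirroring the Llama tests.
        ...
```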
Thank you for the advice! I will do it. Feedback like this is very helpful for learning and applying best practices :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot for this PR, it looks good now. I also appreciate extending the existing tests.
Let's wait for @pacman100's final review before merging.
@pacman100 please let me know if there is something to change in the PR.
@PrakharSaxena24 Heads up, there is now a merge conflict, which stems from a recent PR where we switched from unittest-style asserts to pytest-style asserts.
@BenjaminBossan Changed the asserts to pytest style. Moreover, there was a conflict in utils.py, which was also resolved. Please have a look! Edited: Seems like something is breaking in transformers (CI), will have a look at it tomorrow; however, if you have any idea why, that would be very helpful!
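For context, an illustration of the unittest-to-pytest assert style mentioned above; the test functions and values are made up, not lines from the PR:

```python
import pytest


def test_truthy_value():
    value = [1, 2, 3]
    # unittest style would be: self.assertTrue(len(value) == 3)
    assert len(value) == 3


def test_raises():
    # unittest style would be: with self.assertRaises(ValueError): ...
    with pytest.raises(ValueError):
        int("not a number")
```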
Thanks a lot, the PR LGTM. Let's wait for a final review by @pacman100 before merging.
@BenjaminBossan
Yes, there will be conflicts, so whoever comes last will have to resolve them :) I don't think it's a huge issue. Since both PRs have tests, we should hopefully have the guard rails to ensure that resolving the merge conflict won't lead to a regression in the other PR.
Thank you for the reply!
Thank you @PrakharSaxena24 for supporting Mistral with Adaptation Prompt and the detailed tests! ✨
@BenjaminBossan @pacman100
* Support Mistral For llama-adapter
* Update src/peft/tuners/adaption_prompt/layer.py
  Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
* Update src/peft/tuners/adaption_prompt/layer.py
  Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
* corrected logic and added test
* removed commented out code
* Added seperate test functions for mistral
* missed self.assert
* ruff formatting
---------
Co-authored-by: Prakhar Saxena <prakharsxena11111@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Hello PEFT team,
Purpose of This PR:
Add support for Mistral model for llama-adapter method.
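Once this lands, usage would presumably mirror the existing llama-adapter flow; here is a hedged sketch where the checkpoint name and hyperparameters are illustrative, not prescribed by the PR:

```python
from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

# Illustrative checkpoint and hyperparameters; adjust to your setup.
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
peft_config = AdaptionPromptConfig(adapter_len=10, adapter_layers=30, task_type="CAUSAL_LM")
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
```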
Background:
I wanted to test how the method in this paper works with Mistral-based models compared to LoRA. Initially I thought that, since the architectures of Llama and Mistral are almost the same, this could be achieved by just changing the config; however, I found that the Mistral model's k_proj and v_proj dimensions are different from those of Llama.
Hence I added support for Mistral in the llama-adapter method (the naming is confusing).
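To make the dimension mismatch concrete, a rough comparison based on the published 7B configs (treat the numbers as illustrative and double-check against the actual config files):

```python
# Llama-2-7B:  32 attention heads, 32 KV heads -> k_proj/v_proj output 32 * 128 = 4096
# Mistral-7B:  32 attention heads,  8 KV heads -> k_proj/v_proj output  8 * 128 = 1024
head_dim = 128
llama_kv_out = 32 * head_dim    # 4096
mistral_kv_out = 8 * head_dim   # 1024
print(llama_kv_out, mistral_kv_out)
```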
I hope it will be useful for anyone else willing to experiment with different methods.
Request for Review:
Please review and let me know if my implementation makes sense.
Thank you for all your hard work!