
Add Support for Mistral Model in Llama-Adapter Method #1433

Merged
9 commits merged into huggingface:main on Mar 12, 2024

Conversation

PrakharSaxena24 (Contributor)

Hello PEFT team,
Purpose of This PR:
Add support for the Mistral model in the llama-adapter method.

Background:
I wanted to test how the method from this paper works with Mistral-based models, compared to LoRA. Initially I thought that, since the architectures of Llama and Mistral are almost the same, this could be achieved just by changing the config; however, I found that the Mistral model's k_proj and v_proj dimensions differ from Llama's, because Mistral uses grouped query attention (see the sketch below).
Hence I added support for Mistral in the llama-adapter method (the naming is confusing).
I hope it will be useful for anyone else who wants to experiment with different methods.
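To make the mismatch concrete, here is a rough sketch of the projection sizes (the numbers are the published Llama-7B and Mistral-7B configurations; the variable names just mirror the usual transformers config attributes and are not code from this PR):

```python
# Illustrative only: why k_proj/v_proj output sizes differ between Llama-7B (MHA) and Mistral-7B (GQA).
hidden_size = 4096
num_attention_heads = 32
head_dim = hidden_size // num_attention_heads              # 128 in both models

# Llama-7B: every query head has its own key/value head.
llama_num_key_value_heads = 32
llama_kv_out = llama_num_key_value_heads * head_dim        # 4096 -> k_proj/v_proj map 4096 -> 4096

# Mistral-7B: grouped query attention, 32 query heads share 8 key/value heads.
mistral_num_key_value_heads = 8
mistral_kv_out = mistral_num_key_value_heads * head_dim    # 1024 -> k_proj/v_proj map 4096 -> 1024

print(llama_kv_out, mistral_kv_out)                        # 4096 1024
```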

Request for Review:

Please review and let me know if my implementation makes sense.

Thank you for all your hard work!

@BenjaminBossan (Member) left a comment

Thanks for adding support for Mistral. This looks very promising. I only have a couple of comments, please check them out. Also, would it be possible to add a unit test to tests/test_adaption_prompt.py involving a small mistral model?

Three resolved review threads on src/peft/tuners/adaption_prompt/layer.py (two outdated)
@pacman100 (Contributor) left a comment

Hello @PrakharSaxena24, thank you for the PR, but the logic seems incorrect, as mentioned in the comments. Please look at the Mistral modeling file for the correct way of handling GQA (grouped query attention).

  adapter_k = (
-     key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
+     key.view(1, self.adapter_len, self.model.num_heads, (self.model.head_dim // factor))
Contributor

The head dim shouldn't change but the number of heads should be reduced in GQA.
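To illustrate the point, a minimal sketch with made-up shapes (not the PR's actual code):

```python
import torch

adapter_len, num_heads, head_dim = 10, 32, 128
num_key_value_heads = 8                 # Mistral-7B; factor = num_heads // num_key_value_heads = 4
key = torch.randn(1, adapter_len, num_key_value_heads * head_dim)  # adapter k_proj output: (1, 10, 1024)

# Wrong (the diff above): shrink head_dim by the grouping factor. The reshape happens to fit
# numerically (32 * 32 == 8 * 128) but scrambles which values belong to which head.
# adapter_k = key.view(1, adapter_len, num_heads, head_dim // 4)

# Right: keep head_dim, reduce the head count to the key/value head count.
adapter_k = key.view(1, adapter_len, num_key_value_heads, head_dim)
print(adapter_k.shape)                  # torch.Size([1, 10, 8, 128])
```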

Contributor Author

I see! Thanks a lot, this seems correct.
Will edit this.

Contributor Author

Also I think I will need to do the same in utils.py

Two outdated review threads on src/peft/tuners/adaption_prompt/layer.py (resolved)
@@ -100,6 +104,15 @@ def forward(self, **kwargs):
         query_states = compute_query_states(model=self.model, **kwargs)

         previous_dtype = query_states.dtype

+        # Reshape and average the extra tensors
Contributor

No need to reshape and average the query states: the key shape above is (bsz, adapter_seq_len, num_kv_heads, head_dim), the value shape is (bsz, adapter_seq_len, num_kv_heads, head_dim), and the query shape is (bsz, adapter_seq_len, num_heads, head_dim). You now need to repeat the key/value heads so that num_kv_heads matches num_heads, as done in https://github.com/huggingface/transformers/blob/1c31b7aa3bb4e7ef24c77596d2a76f45a770159f/src/transformers/models/mistral/modeling_mistral.py#L193. After that, the attention computation is the same as in the normal MHA case.
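For reference, the repeat_kv helper used in the linked Mistral modeling file looks roughly like this (paraphrased from transformers; see the pinned link above for the authoritative version):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value heads so their count matches the number of query heads.

    Expects heads on dim 1, i.e. (batch, num_key_value_heads, seq_len, head_dim), and is
    equivalent to torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```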

@PrakharSaxena24 (Contributor Author) commented on Feb 7, 2024

Thanks a lot. So rather than repeating the adapter output, I should repeat adapter_k and adapter_v:
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
since the key/value shape is (bsz, num_kv_heads, adapter_seq_len, head_dim) (dim 1 is num_kv_heads).
Does this make sense?
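A quick shape check of that proposal, with illustrative sizes only:

```python
import torch

bsz, num_kv_heads, adapter_seq_len, head_dim = 1, 8, 10, 128
factor = 32 // num_kv_heads   # num_heads // num_kv_heads = 4 for Mistral-7B

adapter_k = torch.randn(bsz, num_kv_heads, adapter_seq_len, head_dim)
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
print(adapter_k.shape)        # torch.Size([1, 32, 10, 128]) -- head count now matches the 32 query heads
```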

PrakharSaxena24 and others added 3 commits February 7, 2024 18:04
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
@PrakharSaxena24 (Contributor Author)

@BenjaminBossan @pacman100 thank you for your kind comments. I have corrected the logic and added a test.
I did not add a test for test_bf16_inference as I could not find any Mistral model here.
Please have a look.
Thanks for your hard work!

@BenjaminBossan (Member) left a comment

Thanks a lot. From my side, there are only a few comments left, please take a look.

Outdated review thread on tests/test_adaption_prompt.py (resolved)
@@ -78,7 +106,19 @@ def test_attributes(self) -> None:
         self.assertTrue(hasattr(model, "from_pretrained"))
         self.assertTrue(hasattr(model, "push_to_hub"))

+        # Test Mistral
+        if self.mistral_available:
Member

Instead of attaching the mistral tests to the llama tests, could you please create a separate test for each? You can decorate them with @unittest.skipIf(not is_mistral_available()) to avoid the if self.mistral_available: line.
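The suggested pattern would look roughly like the sketch below; the helper, class, and test names are placeholders rather than what the PR ultimately uses, and note that unittest.skipIf also takes a reason string:

```python
import unittest


def is_mistral_available() -> bool:
    # Placeholder helper: the real test file would probe transformers for Mistral support.
    try:
        from transformers import MistralConfig  # noqa: F401

        return True
    except ImportError:
        return False


@unittest.skipIf(not is_mistral_available(), "Mistral is not available in this transformers version")
class MistralAdaptionPromptTester(unittest.TestCase):
    def test_attributes(self) -> None:
        # Build a tiny Mistral model and attach the adaption prompt here,
        # mirroring the Llama tests instead of branching on self.mistral_available.
        ...
```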

Contributor Author

Thank you for the advice! I will do it. Good feedback is very helpful for me to learn and apply the best practices :)

Outdated review thread on src/peft/tuners/adaption_prompt/config.py (resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks a lot for this PR, it looks good now. I also appreciate extending the existing tests.

Let's wait for @pacman100's final review before merging.

@PrakharSaxena24 (Contributor Author)

@pacman100 please let me know if there is something to change in the PR.

@BenjaminBossan (Member)

@PrakharSaxena24 Heads up, there is now a merge conflict, which stems from a recent PR where we switched from unittest-style self.assertFoo(...) to pytest-style plain asserts. Could you please fix the conflict?
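The conversion behind that conflict is mechanical; roughly (illustrative example, not lines from the PR):

```python
import unittest


class ExampleTest(unittest.TestCase):
    def test_unittest_style(self):
        # Old style: unittest assertion helpers.
        self.assertTrue(hasattr([], "append"))
        self.assertEqual(len([1, 2]), 2)

    def test_pytest_style(self):
        # New style after the switch: plain asserts, which pytest introspects on failure.
        assert hasattr([], "append")
        assert len([1, 2]) == 2
```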

@PrakharSaxena24 (Contributor Author)

PrakharSaxena24 commented Feb 19, 2024

@BenjaminBossan I changed the asserts to pytest style. There was also a conflict in utils.py, which has been resolved. Please have a look!

Edited: It seems like something is breaking in transformers (CI); I will have a look at it tomorrow, but if you have any idea why, that would be very helpful!

@BenjaminBossan (Member) left a comment

Thanks a lot, the PR LGTM. Let's wait for a final review by @pacman100 before merging.

@PrakharSaxena24 (Contributor Author)

@BenjaminBossan
The PR above and this PR do very similar things: the current PR adds Mistral and also Llama 2 34B and 70B (GQA), while the PR above adds GQA (Llama 2 34B and 70B).
I think there might be conflicts if both are merged into main.
How do you think we should proceed?

@BenjaminBossan (Member)

> I think there might be conflicts if both are merged into main.
> How do you think we should proceed?

Yes, there will be conflicts, so whoever comes last will have to resolve them :)

I don't think it's a huge issue. Since both PRs have tests, we should hopefully have the guard rails to ensure that resolving the merge conflict won't lead to a regression in the other PR.

@PrakharSaxena24 (Contributor Author)

> I don't think it's a huge issue. Since both PRs have tests, we should hopefully have the guard rails to ensure that resolving the merge conflict won't lead to a regression in the other PR.

Thank you for the reply!

@pacman100 (Contributor) left a comment

Thank you @PrakharSaxena24 for supporting Mistral with Adaptation Prompt and the detailed tests! ✨

@PrakharSaxena24 (Contributor Author)

PrakharSaxena24 commented Mar 12, 2024

@BenjaminBossan @pacman100
Thank you for your time and guidance.

@pacman100 merged commit d28fffb into huggingface:main on Mar 12, 2024
14 checks passed
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
* Support Mistral For llama-adapter

* Update src/peft/tuners/adaption_prompt/layer.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/peft/tuners/adaption_prompt/layer.py

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* corrected logic and added test

* removed commented out code

* Added seperate test functions for mistral

* missed self.assert

* ruff formatting

---------

Co-authored-by: Prakhar Saxena <prakharsxena11111@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>