Fix bug when loading local peft model #342

Merged · 4 commits · May 11, 2023

Conversation

@Opdoop (Contributor) commented May 5, 2023

Fix the bug in #341.

When the local model weights and the LoRA weights are saved in different locations, the code attempts hf_hub_download and raises the HFValidationError: Repo id must be in the form 'repo_name' error twice. The first error is caught, but the second one terminates the program.

After loading the local peft model, we can update pretrained_model_name_or_path to pretrained_model to avoid entering the resume_training branch.
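
A rough sketch of the idea (the names follow the surrounding code in modeling_base.py; this illustrates the proposed change, not the exact merged diff):

# peft branch of PreTrainedModelWrapper.from_pretrained: wrap the base model
# with the local adapter, as the code already does
pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path)

# proposed change: point the "name or path" at the already-loaded model object
# so the later resume_training checks no longer treat the adapter-only
# directory as a full checkpoint and try to fetch pytorch_model.bin from the Hub
pretrained_model_name_or_path = pretrained_model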

@younesbelkada (Contributor) left a comment

Hi @Opdoop
Thanks for the PR, but I'm not sure this is the right fix.
This block: https://github.com/lvwerra/trl/blob/c85cdbdbd0d956bbc6f5a8e04b0036648480e7a1/trl/models/modeling_base.py#L153 should check whether the adapter model is present and load it properly using the adapter config.
Can you share a minimal script to reproduce your bug? I suspect you might be loading/saving your model incorrectly.

@Opdoop (Contributor, Author) commented May 11, 2023

Hi @younesbelkada, thanks for your reply. It may be that I'm using trl in the wrong way.

In my case, the LoRA weights and the base model weights are saved in different locations, which causes the following lines to break.
https://github.com/lvwerra/trl/blob/c85cdbdbd0d956bbc6f5a8e04b0036648480e7a1/trl/models/modeling_base.py#L205-L220

A minimal script is below.

from trl import AutoModelForCausalLMWithValueHead
import torch
model = AutoModelForCausalLMWithValueHead.from_pretrained(
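    # local LoRA adapter directory containing only adapter_config.json and adapter_model.bin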
    "../local_lora_path",
    torch_dtype=torch.float16,
)

where I pass the relative path of my locally fine-tuned model. The if block at line 153 runs fine: it correctly loads the LoRA weights and finds the corresponding base model weights.

That is, pretrained_model_name_or_path is the string local_lora_path. local_lora_path contains only two files, adapter_config.json and adapter_model.bin. adapter_config.json has "base_model_name_or_path": "../yahma-llama-7b-hf/", which is the base model path; all of the base model's weights, config, and tokenizer are saved in ../yahma-llama-7b-hf/.

So the following loading code worked correctly:
https://github.com/lvwerra/trl/blob/c85cdbdbd0d956bbc6f5a8e04b0036648480e7a1/trl/models/modeling_base.py#L153-L167

Then it enters the resume_training if branch, because pretrained_model_name_or_path is the string local_lora_path. But local_lora_path has neither pytorch_model.bin nor pytorch_model.bin.index.json, so the code breaks.
https://github.com/lvwerra/trl/blob/c85cdbdbd0d956bbc6f5a8e04b0036648480e7a1/trl/models/modeling_base.py#L205-L220

Thus, I think adding pretrained_model_name_or_path = pretrained_model will solve this problem.

Looking forward to your response.

@younesbelkada (Contributor) left a comment

Thanks for explaining, this is much clearer!
I managed to reproduce the issue; it only happens if you give a path to a local peft model that was not saved with AutoModelForCausalLMWithValueHead.
As you can see here: https://github.com/lvwerra/trl/blob/dec9993129f9b77cb81c46f2662119a824b4f6cc/tests/test_peft_models.py#L128 there is a test that checks that loading a model saved using peft works correctly. However, it assumes you saved the model with AutoModelForCausalLMWithValueHead.save_pretrained(), which saves the v_head inside the pytorch_model.bin file.
I think the right fix is to add a try/except check around lines 218 to 220 that determines whether the ***.index.json file exists, and assign the result to a new boolean is_resuming_training. If is_resuming_training is True, execute lines 222-229; otherwise skip them and assume we are not continuing training.
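
A minimal sketch of that check (variable names approximate the surrounding modeling_base.py code, not the exact source; pretrained_model_name_or_path comes from the enclosing from_pretrained call, and the real code would handle the sharded pytorch_model.bin.index.json case the same way):

import os
import torch
from huggingface_hub import hf_hub_download

filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
is_resuming_training = True

if not os.path.exists(filename):
    try:
        # for an adapter-only local directory this raises (HFValidationError /
        # EntryNotFoundError), because there is no pytorch_model.bin to fetch
        filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
    except Exception:
        is_resuming_training = False

if is_resuming_training:
    # only in this case do we expect a full checkpoint that also contains the v_head
    state_dict = torch.load(filename, map_location="cpu")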
It would also be great if you could add a test similar to https://github.com/lvwerra/trl/blob/dec9993129f9b77cb81c46f2662119a824b4f6cc/tests/test_peft_models.py#L128, but we can do that in a follow-up PR!
Thanks again for spotting this

@Opdoop (Contributor, Author) commented May 11, 2023

Hi @younesbelkada, thanks for your quick response! Yes, you are right: the LoRA model was fine-tuned with peft, and it was not saved with AutoModelForCausalLMWithValueHead.

I am trying to initialize the LoRA weights from a fine-tuned LoRA model before PPO training; I think it may help the PPO training. (I'm not sure, what do you think?)
I'm not sure I understand correctly: should I close this pull request and open a new one to implement the try/except check?

@younesbelkada (Contributor) commented

Thanks for iterating!
I think you can just continue in this PR; let me know when it's ready for review again.

Opdoop added a commit to Opdoop/trl that referenced this pull request on May 11, 2023:

Fix loading bug when load lora model but not resuming training
1. Implement the fix logic described in huggingface#342 (review)
2. Set peft lora weight to trainable.
@Opdoop (Contributor, Author) commented May 11, 2023

Hi @younesbelkada, I implemented the try/except block. I think it is ready for review again.

@HuggingFaceDocBuilderDev commented May 11, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor) left a comment

Thanks a lot for this! I left one open question. I think we should address the is_trainable change in another PR; after that we should be good to merge.

@@ -164,7 +163,9 @@ class and the arguments that are specific to trl models. The kwargs
     peft_config.base_model_name_or_path, *model_args, **pretrained_kwargs
 )

-pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path)
+pretrained_model = PeftModel.from_pretrained(
+    pretrained_model, pretrained_model_name_or_path, is_trainable=True
@younesbelkada (Contributor) commented on this diff:

I think is_trainable is only relevant when the peft config is different from LoraConfig, no?
Also, what if a user loads the model just for inference? I think we can leave this change for a future PR and add is_trainable as a kwarg on from_pretrained of PreTrainedWrapper.
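
For illustration, the suggested follow-up would let callers opt in explicitly, along these lines (is_trainable here is a hypothetical kwarg of the wrapper, not something added in this PR):

from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "../local_lora_path",
    is_trainable=True,  # hypothetical: would be forwarded to PeftModel.from_pretrained
)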

@Opdoop (Contributor, Author) replied:

Oh, my mistake. I think you're right; I only considered my own use case. 😅
I'll open another PR to add is_trainable as a kwarg.

@younesbelkada (Contributor) replied:

Awesome thanks so much!

Leave is_trainable to future PR.
@younesbelkada (Contributor) left a comment

Awesome work! Thanks a lot for fixing this bug 💪 This looks great to me!

@younesbelkada requested a review from lvwerra on May 11, 2023 at 15:09
@lvwerra (Member) left a comment

Thanks for the fix - looks good to me, maybe we can add a test so it doesn't happen again in the future.

@younesbelkada (Contributor) left a comment

@Opdoop do you feel comfortable adding new tests?
The new test would be similar to this one: https://github.com/lvwerra/trl/blob/e0172fc8ecb11dfac0410feb1bf2a5c80ec9418b/tests/test_peft_models.py#L128. Otherwise, I can do it by pushing to your branch.

@Opdoop (Contributor, Author) commented May 11, 2023

I'm willing to try adding a test, but I don't quite get it. Do you mean we pass a 'lora_path' to AutoModelForCausalLMWithValueHead.from_pretrained('lora_path')? I'm not sure which model on https://huggingface.co/trl-internal-testing is suitable for this test.

@younesbelkada (Contributor) commented

The workflow for that test would be: first, save a dummy peft model inside a tmp_dir:

# inside the test class; uses tempfile, transformers.AutoModelForCausalLM and peft.get_peft_model
causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id)
pretrained_model = get_peft_model(causal_lm_model, self.lora_config)

with tempfile.TemporaryDirectory() as tmp_dir:
    pretrained_model.save_pretrained(tmp_dir)

and from there load the AutoModelForCausalLMWithValueHead directly from tmp_dir. There is no need to check for the v_head weights afterwards, I think; I would just check for the presence of the adapter model & config in tmp_dir.
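
Putting the workflow together, a sketch of such a test might look like this (the model id, test name, and assertions are illustrative placeholders, not the test that was merged; depending on the peft version the adapter weights file is adapter_model.bin or adapter_model.safetensors):

import os
import tempfile

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from trl import AutoModelForCausalLMWithValueHead

# placeholders: any small causal LM checkpoint and LoRA config will do
causal_lm_model_id = "sshleifer/tiny-gpt2"
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")


def test_load_value_head_model_from_peft_only_dir():
    # save a dummy peft model (adapter only, no v_head) into a temporary directory
    causal_lm_model = AutoModelForCausalLM.from_pretrained(causal_lm_model_id)
    pretrained_model = get_peft_model(causal_lm_model, lora_config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        pretrained_model.save_pretrained(tmp_dir)

        # the adapter config is on disk, but there is no pytorch_model.bin with a v_head
        assert "adapter_config.json" in os.listdir(tmp_dir)
        assert "pytorch_model.bin" not in os.listdir(tmp_dir)

        # loading the value-head wrapper from the adapter-only directory should not raise
        model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
        assert model is not None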

@Opdoop (Contributor, Author) commented May 11, 2023

Thanks for the detailed explanation. I'll try it. 👍

Check that the model saved with peft class interface can be loaded properly.
@Opdoop (Contributor, Author) commented May 11, 2023

❤Thank you for your careful and patient guidance, I learned a lot from it.

@younesbelkada requested a review from lvwerra on May 11, 2023 at 16:35
@younesbelkada merged commit 31cc361 into huggingface:main on May 11, 2023
@Opdoop deleted the patch-1 branch on May 12, 2023 at 07:11