PPOTrainer + LoRA and Continued Training #2707

Open
kooryan opened this issue Jan 30, 2025 · 0 comments
Labels
⏳ needs more info · ⚡ PEFT · 🏋 PPO

Comments

kooryan commented Jan 30, 2025

Hi all,

I’m currently training a model with PPOTrainer and LoRA.

When I do

model.save_pretrained(…)

it saves both an adapter_model.safetensors and a pytorch_model.bin.

What is the difference between the two? They are both in the same directory, but it seems that when I load a model via from_pretrained, it uses the adapter_model.

Does the pytorch_model.bin also have the LoRA adapters merged in?
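
A quick way to check which weights each file actually contains would be something like this (just a sketch; checkpoint_dir is a placeholder for my save directory):

import os
import torch
from safetensors.torch import load_file

checkpoint_dir = "path/to/checkpoint"  # placeholder

# adapter_model.safetensors should only hold the PEFT/LoRA weights
adapter_sd = load_file(os.path.join(checkpoint_dir, "adapter_model.safetensors"))
print(sorted(adapter_sd.keys())[:5])

# pytorch_model.bin is the state dict saved by the value-head wrapper;
# looking for "lora" and "v_head" keys shows whether the adapters are
# kept separate or merged into the base weights
full_sd = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location="cpu")
print([k for k in full_sd if "lora" in k.lower()][:5])
print([k for k in full_sd if k.startswith("v_head.")])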

Additionally, I want to continue PPO training from a checkpoint. I load the checkpoint like this:

policy = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    quantization_config=nf4_config,
)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name, quantization_config=nf4_config
)
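
(I've also wondered whether I should instead attach the saved adapter directly with PEFT rather than passing a fresh peft_config. A rough, untested sketch of what I mean, assuming best_checkpoint_path contains the saved adapter_config.json:)

from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# load the base model, then attach the trained LoRA adapter from the checkpoint
base = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=nf4_config)
base = PeftModel.from_pretrained(base, best_checkpoint_path, is_trainable=True)
policy = AutoModelForCausalLMWithValueHead.from_pretrained(base)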

Then I directly load the model parameters of the v_head and pretrained_model like this:

import os
import torch

# load the full checkpoint and split it into v_head weights and everything else
model_dict = torch.load(
    os.path.join(best_checkpoint_path, "pytorch_model.bin"),
    map_location=lambda storage, loc: storage,
)
v_head_state_dict = {}
pretrained_model_state_dict = {}
for k, v in model_dict.items():
    if k.startswith("v_head."):
        v_head_state_dict[k.replace("v_head.", "")] = v
    else:
        pretrained_model_state_dict[k] = v

policy.v_head.load_state_dict(v_head_state_dict)

# load the remaining (non-v_head) weights into the wrapped base model
_load_state_dict_into_model(
    policy.pretrained_model,
    pretrained_model_state_dict,
    start_prefix="",
)

as well as the optimizer state. One way I've tried to load the model weights was loading the state_dict of the adapter_model. However, it's missing keys for the v_head since it's just a LoRA adapter. How can I verify that training is resuming properly from the checkpoint with LoRA?
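
The kind of check I had in mind is something like the sketch below (reusing v_head_state_dict from above; my understanding is that freshly initialized lora_B matrices are zeros, so non-zero values would suggest the trained adapter actually got loaded), but I'm not sure it's the right approach:

import torch

# LoRA B matrices start out as zeros; non-zero values after loading suggest
# the trained adapter weights were actually restored
for name, param in policy.pretrained_model.named_parameters():
    if "lora_B" in name:
        print(name, "non-zero" if param.abs().sum() > 0 else "still zero")

# the v_head parameters should match exactly what was saved in pytorch_model.bin
for k, v in v_head_state_dict.items():
    assert torch.allclose(policy.v_head.state_dict()[k].cpu(), v.cpu())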

I am using versions

transformers             4.48.1
trl                      0.9.6
github-actions bot added the ⏳ needs more info, ⚡ PEFT, and 🏋 PPO labels on Jan 30, 2025