
dpo training with Lora can not save fine-tuned weights #742

Closed
LuJunru opened this issue Sep 6, 2023 · 13 comments · Fixed by #956

Comments

@LuJunru

LuJunru commented Sep 6, 2023

Issue

The following script manages to train and save. However, the saved weights are incorrect.

Training script

# imports for this snippet; ModelArguments, DataTrainingArguments and
# prepared_dataset are defined elsewhere in the original script
from transformers import (
    HfArgumentParser,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
    TrainingArguments,
    set_seed,
)
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # load config and tokenizer
    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
    config.use_cache = False
    tokenizer = LlamaTokenizer.from_pretrained(model_args.model_name_or_path, truncation_side='left')

    # initialize modules
    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
    model.enable_input_require_grads()

    # add pad token in tokenizer if needed
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token":"<pad>"})
        tokenizer.pad_token_id = 0

    # Setup seed
    set_seed(training_args.seed)
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Setup Trainer
    training_args = training_args.to_dict()
    training_args |= {'remove_unused_columns': False}
    training_args = TrainingArguments(**training_args)
    peft_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )
    model_peft = get_peft_model(model, peft_config)
    trainer = DPOTrainer(
        model=model_peft,
        ref_model=None,
        beta=0.1,  # DPO temperature
        train_dataset=prepared_dataset["train"],
        eval_dataset=prepared_dataset["eval"],
        tokenizer=tokenizer,
        args=training_args,
        peft_config=peft_config,
        max_length=data_args.model_max_length,
        max_prompt_length=int(data_args.model_max_length) * 3 // 4,
    )

    # Training
    train_result = trainer.train()
    trainer.save_state()
    trainer.save_model()

Training output

{'loss': 0.6934, 'learning_rate': 0.0, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -63.125, 'logps/chosen': -84.625, 'logits/rejected': -1.3544921875, 'logits/chosen': -1.353515625, 'epoch': 0.0}
{'loss': 0.6933, 'learning_rate': 0.0005, 'rewards/chosen': 0.0263671875, 'rewards/rejected': 0.0164031982421875, 'rewards/accuracies': 0.5625, 'rewards/margins': 0.00995635986328125, 'logps/rejected': -55.625, 'logps/chosen': -38.59375, 'logits/rejected': -1.24609375, 'logits/chosen': -1.2490234375, 'epoch': 0.0}
{'loss': 0.6938, 'learning_rate': 0.0005, 'rewards/chosen': 0.0035152435302734375, 'rewards/rejected': 0.006542205810546875, 'rewards/accuracies': 0.5, 'rewards/margins': -0.0030307769775390625, 'logps/rejected': -66.75, 'logps/chosen': -65.0625, 'logits/rejected': -1.296875, 'logits/chosen': -1.2890625, 'epoch': 0.0}
{'loss': 0.6814, 'learning_rate': 0.0003333333333333333, 'rewards/chosen': 0.05535888671875, 'rewards/rejected': -0.007030487060546875, 'rewards/accuracies': 0.75, 'rewards/margins': 0.062408447265625, 'logps/rejected': -70.9375, 'logps/chosen': -40.875, 'logits/rejected': -1.255859375, 'logits/chosen': -1.259765625, 'epoch': 0.01}
{'loss': 0.6477, 'learning_rate': 0.00016666666666666666, 'rewards/chosen': -0.0174407958984375, 'rewards/rejected': -0.07177734375, 'rewards/accuracies': 0.5, 'rewards/margins': 0.054351806640625, 'logps/rejected': -65.75, 'logps/chosen': -54.28125, 'logits/rejected': -1.2783203125, 'logits/chosen': -1.28515625, 'epoch': 0.01}

Weight check when doing merge

# imports for this merge-check snippet; script_args is parsed elsewhere
from transformers import LlamaConfig, LlamaForCausalLM
from peft import PeftModel

# initialize modules
config = LlamaaConfig.from_pretrained(script_args.base_model_name) if False else LlamaConfig.from_pretrained(script_args.base_model_name)
model = LlamaForCausalLM.from_pretrained(script_args.base_model_name, config=config)

print("-" * 20 + "Weight before merge" + "-" * 20)
print(model.get_output_embeddings().weight)

# Load the Lora model
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
model.eval()

# check Lora weights
print("-" * 20 + "Check Lora Weights" + "-" * 20)
print(model.model.model.layers[0].self_attn.q_proj.lora_B.default.weight)

# merge lora weight and base model
model = model.merge_and_unload()

print("-" * 20 + "Weight after merge" + "-" * 20)
print(model.get_output_embeddings().weight)

Weight check outputs

--------------------Weight before merge--------------------
Parameter containing:
tensor([[-0.0027,  0.0020, -0.0072,  ...,  0.0034, -0.0074,  0.0074],
        [-0.0315,  0.0452, -0.0030,  ..., -0.0226,  0.0144,  0.0317],
        [-0.0127,  0.0016,  0.0189,  ..., -0.0264,  0.0157, -0.0071],
        ...,
        [ 0.0199,  0.0242,  0.0271,  ...,  0.0052, -0.0103, -0.0067],
        [ 0.0074, -0.0048,  0.0076,  ..., -0.0273, -0.0171,  0.0308],
        [ 0.0192,  0.0271,  0.0170,  ..., -0.0015, -0.0046, -0.0046]],
       requires_grad=True)
--------------------Check Lora Weights--------------------
Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
--------------------Weight after merge--------------------
Parameter containing:
tensor([[-0.0027,  0.0020, -0.0072,  ...,  0.0034, -0.0074,  0.0074],
        [-0.0315,  0.0452, -0.0030,  ..., -0.0226,  0.0144,  0.0317],
        [-0.0127,  0.0016,  0.0189,  ..., -0.0264,  0.0157, -0.0071],
        ...,
        [ 0.0199,  0.0242,  0.0271,  ...,  0.0052, -0.0103, -0.0067],
        [ 0.0074, -0.0048,  0.0076,  ..., -0.0273, -0.0171,  0.0308],
        [ 0.0192,  0.0271,  0.0170,  ..., -0.0015, -0.0046, -0.0046]])

I think that even if the LoRA fine-tuning has not converged, the saved adapter values should not be all zeros?

@Moyhub

Moyhub commented Sep 8, 2023

I have the same question~

@lvwerra
Member

lvwerra commented Sep 8, 2023

Maybe @younesbelkada and @kashif could have a look.

@kashif
Collaborator

kashif commented Sep 8, 2023

@LuJunru did you try to save the model via:

model_peft.save_pretrained("peft_checkpoint")

@Moyhub

Moyhub commented Sep 8, 2023

@kashif In fact, I had the same problem when training DPO (without accelerate). I believe the code here is what you described, but it doesn't work:

self.model.save_pretrained(output_dir)

@LuJunru
Author

LuJunru commented Sep 8, 2023

@kashif @Moyhub @lvwerra Hi guys, thank you for the feedback. When I checked the weights directly in the adapter bin file, I found that the weights were saved. However, the weight keys there were base_model.model.base_model.model.xxxx, which may be related to the deepspeed wrapper. I renamed the keys to base_model.model.xxxx and the merge succeeded. You may want to check this as well.
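For reference, a minimal sketch of the key-renaming workaround described above, assuming the adapter was saved as a single adapter_model.bin file (the path is a hypothetical placeholder; the prefixes come from the comment, not from the TRL code):

import torch

# hypothetical path to the saved adapter state dict
adapter_path = "output_dir/adapter_model.bin"
state_dict = torch.load(adapter_path, map_location="cpu")

# strip one level of the duplicated prefix:
# "base_model.model.base_model.model.xxx" -> "base_model.model.xxx"
fixed_state_dict = {
    k.replace("base_model.model.base_model.model.", "base_model.model.", 1): v
    for k, v in state_dict.items()
}

torch.save(fixed_state_dict, adapter_path)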

@lvwerra
Member

lvwerra commented Sep 8, 2023

@younesbelkada something we could also test with #724.

@AntoineBlanot

It seems like when you feed a PeftModel and a peft_config to the DPOTrainer, the model to be trained gets nested into two PeftModel wrappers, which is why the keys in the saved model are base_model.model.base_model.model.xxx instead of base_model.model.xxx.
In that case, to load the saved DPO model we must create a PeftModel on top of the PeftModel that was used as a reference model. Doing this, I managed to load the weights correctly.
Also, when merging the weights, you must call merge_and_unload twice, as follows:
merged_model = dpo_model.merge_and_unload().merge_and_unload()

@iampushpdeep

@AntoineBlanot Can you provide a code example for your solution?
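A minimal sketch of what that double-wrapped loading and merging could look like, assuming the first adapter (the one wrapped around the base model before it was handed to DPOTrainer) and the DPO adapter were saved to separate directories; model names and paths are hypothetical placeholders:

from peft import PeftModel
from transformers import LlamaForCausalLM

base = LlamaForCausalLM.from_pretrained("base_model_name")

# first wrapper: the adapter the base model already carried
# when it was passed to DPOTrainer
model = PeftModel.from_pretrained(base, "first_adapter_dir")

# second wrapper: the adapter saved by the DPO training run,
# loaded on top of the first PeftModel
model = PeftModel.from_pretrained(model, "dpo_adapter_dir")

# undo both wrappers: each merge_and_unload() folds in one adapter level
merged_model = model.merge_and_unload().merge_and_unload()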

@KD-12

KD-12 commented Oct 23, 2023

Seeing the same issue. Because of this, model performance after DPO training is the same as the original model's. Any conclusion on this bug?
@lvwerra can you point to any other similar thread that has a solution for this?

@Elfsong

Elfsong commented Nov 5, 2023

Same issue here

@Elfsong

Elfsong commented Nov 5, 2023

@LuJunru Could you share the example code? Thanks.

@kashif
Collaborator

kashif commented Nov 5, 2023

I have a PR that fixes this issue by merging the initial peft adapter if the trainer gets an additional peft_config. Can you kindly try it out?
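A rough sketch of the idea described here (not the actual TRL implementation; the function name is hypothetical): if the model passed in is already a PeftModel and an additional peft_config is supplied, the existing adapter is merged first so that only a single adapter level remains:

from peft import PeftModel, get_peft_model

def prepare_peft_model(model, peft_config):
    # if the model already carries an adapter, fold it into the base weights ...
    if isinstance(model, PeftModel):
        model = model.merge_and_unload()
    # ... then attach a single fresh adapter for DPO training
    return get_peft_model(model, peft_config)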

@hanyin88

Thanks for the great discussion. Just wish to mention there's a subtle bug/nuance with the most recent PR.

Because the DPO trainer has merged the original base_model and LoRA with merge_and_unload(), the final saved LoRA adapter weights are based on the merged model, NOT the original base_model. This means that at inference time, one should merge the base_model with the original LoRA weights first, before loading the RLHF (DPO) adapter weights.

Personally I find this to be an important detail and it took me a while to figure out. Might be a good idea to highlight it in future DPO tutorials.

Many thanks.
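To illustrate the loading order described above, a sketch with hypothetical paths: merge the original (pre-DPO) LoRA adapter into the base model first, and only then load the adapter saved by DPO training, since that adapter was trained against the merged weights:

from peft import PeftModel
from transformers import LlamaForCausalLM

# 1. base model plus the original LoRA adapter, merged
base = LlamaForCausalLM.from_pretrained("base_model_name")
sft_model = PeftModel.from_pretrained(base, "original_lora_adapter_dir")
sft_model = sft_model.merge_and_unload()

# 2. the DPO adapter goes on top of the merged model
dpo_model = PeftModel.from_pretrained(sft_model, "dpo_lora_adapter_dir")
dpo_model = dpo_model.merge_and_unload()  # optional final merge for inference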

simonmeoni added a commit to arkhn/open-nlp that referenced this issue Nov 4, 2024