From 679e8000594348cb9b2f743640e1aeb8917ee372 Mon Sep 17 00:00:00 2001
From: Moshe Island
Date: Tue, 12 Sep 2023 15:33:34 +0300
Subject: [PATCH] deepspeed-chat: train v_head when only optimizing lora

When using only_optimize_lora, we still need to train the v_head parameter.

Change-Id: I252c3ee69819997bf336482c6779b070f2e76df8
Signed-off-by: Moshe Island
---
 .../training/step2_reward_model_finetuning/main.py        | 3 ++-
 applications/DeepSpeed-Chat/training/utils/module/lora.py | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
index d3352ce3d..563c3e9a9 100644
--- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
+++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py
@@ -252,7 +252,8 @@ def main():
                                                 args.lora_module_name,
                                                 args.lora_dim)
         if args.only_optimize_lora:
-            rm_model = only_optimize_lora_parameters(rm_model)
+            rm_model = only_optimize_lora_parameters(
+                rm_model, force_optimize_params=['v_head.weight'])
             rm_model = make_model_gradient_checkpointing_compatible(rm_model)
 
     train_phase = 2
diff --git a/applications/DeepSpeed-Chat/training/utils/module/lora.py b/applications/DeepSpeed-Chat/training/utils/module/lora.py
index cd37e6496..32c9730b6 100644
--- a/applications/DeepSpeed-Chat/training/utils/module/lora.py
+++ b/applications/DeepSpeed-Chat/training/utils/module/lora.py
@@ -131,10 +131,10 @@ def convert_lora_to_linear_layer(model):
     return model
 
 
-def only_optimize_lora_parameters(model):
+def only_optimize_lora_parameters(model, force_optimize_params=[]):
     # turn off the gradient of all the parameters except the LoRA parameters
     for name, param in model.named_parameters():
-        if "lora_right_weight" in name or "lora_left_weight" in name:
+        if "lora_right_weight" in name or "lora_left_weight" in name or name in force_optimize_params:
             param.requires_grad = True
         else:
             param.requires_grad = False
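
Below is a minimal sketch (not part of the patch) of what the new force_optimize_params argument does. TinyRewardModel is a hypothetical stand-in for the critic model built by create_critic_model; it exists only to show which parameters stay trainable after the call.

    # Sketch only: a toy module, not the actual DeepSpeed-Chat critic model.
    import torch
    import torch.nn as nn

    def only_optimize_lora_parameters(model, force_optimize_params=[]):
        # Same logic as the patched helper in utils/module/lora.py: freeze everything
        # except the LoRA weights and any explicitly listed parameter names.
        for name, param in model.named_parameters():
            if "lora_right_weight" in name or "lora_left_weight" in name or name in force_optimize_params:
                param.requires_grad = True
            else:
                param.requires_grad = False
        return model

    class TinyRewardModel(nn.Module):
        # Hypothetical stand-in: a backbone, a pair of LoRA factors, and a scalar value head.
        def __init__(self, hidden=8):
            super().__init__()
            self.backbone = nn.Linear(hidden, hidden)
            self.lora_right_weight = nn.Parameter(torch.zeros(hidden, 4))
            self.lora_left_weight = nn.Parameter(torch.zeros(4, hidden))
            self.v_head = nn.Linear(hidden, 1, bias=False)

    model = only_optimize_lora_parameters(TinyRewardModel(),
                                          force_optimize_params=['v_head.weight'])
    for name, param in model.named_parameters():
        print(name, param.requires_grad)
    # lora_right_weight True, lora_left_weight True, v_head.weight True,
    # backbone.weight False, backbone.bias False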