Reward model training fails when resuming from a checkpoint #2351

Open
zhanglv0209 opened this issue Jan 26, 2024 · 6 comments
Labels: good first issue (Good for newcomers), pending (This problem is yet to be addressed)

Comments


zhanglv0209 commented Jan 26, 2024

Training failed partway through, so I re-ran the command to resume from the checkpoint.

Command:

/mntenv/llama_etuning/bin/deepspeed --include localhost:4,5,6,7 --master_port=9101 src/train_bash.py \
    --deepspeed ds_config.json \
    --stage rm \
    --do_train \
    --model_name_or_path /mnodel/llama2-Chinese-7b-Chat \
    --dataset comparison_gpt4_zh \
    --template llama2 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir /llama2-Chinese-7b-Chat-20240126/ \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 16 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 100 \
    --learning_rate 1e-6 \
    --num_train_epochs 2.0 \
    --plot_loss \
    --fp16 \
    --preprocessing_num_workers 20

Error:

Traceback (most recent call last):
  File "/mnt/nvme0n1/zhanglv/code/LLaMA-Factory/src/train_bash.py", line 14, in <module>
    main()
  File "/mnt/nvme0n1/zhanglv/code/LLaMA-Factory/src/train_bash.py", line 5, in main
    run_exp()
  File "/mnt/nvme0n1/zhanglv/code/LLaMA-Factory/src/llmtuner/train/tuner.py", line 33, in run_exp
    run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
  File "/mnt/nvme0n1/zhanglv/code/LLaMA-Factory/src/llmtuner/train/rm/workflow.py", line 55, in run_rm
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/transformers/trainer.py", line 1693, in _inner_training_loop
    deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/transformers/integrations/deepspeed.py", line 402, in deepspeed_load_checkpoint
    load_path, _ = deepspeed_engine.load_checkpoint(
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 2697, in load_checkpoint
    load_path, client_states = self._load_checkpoint(load_dir,
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 2762, in _load_checkpoint
    self.load_module_state_dict(checkpoint=checkpoint,
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 2560, in load_module_state_dict
    self.module.load_state_dict(
  File "/mnt/nvme0n1/zhanglv/venv/llama_etuning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AutoModelForCausalLMWithValueHead:
Missing key(s) in state_dict:
  "pretrained_model.base_model.model.model.embed_tokens.weight",
  "pretrained_model.base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.k_proj.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.0.self_attn.o_proj.weight", "pretrained_model.base_model.model.model.layers.0.mlp.gate_proj.weight", "pretrained_model.base_model.model.model.layers.0.mlp.up_proj.weight", "pretrained_model.base_model.model.model.layers.0.mlp.down_proj.weight", "pretrained_model.base_model.model.model.layers.0.input_layernorm.weight", "pretrained_model.base_model.model.model.layers.0.post_attention_layernorm.weight",
  "pretrained_model.base_model.model.model.layers.1.self_attn.q_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.k_proj.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.v_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.1.self_attn.o_proj.weight", "pretrained_model.base_model.model.model.layers.1.mlp.gate_proj.weight", "pretrained_model.base_model.model.model.layers.1.mlp.up_proj.weight", "pretrained_model.base_model.model.model.layers.1.mlp.down_proj.weight", "pretrained_model.base_model.model.model.layers.1.input_layernorm.weight", "pretrained_model.base_model.model.model.layers.1.post_attention_layernorm.weight",
  "pretrained_model.base_model.model.model.layers.2.self_attn.q_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.k_proj.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.v_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.2.self_attn.o_proj.weight", "pretrained_model.base_model.model.model.layers.2.mlp.gate_proj.weight", "pretrained_model.base_model.model.model.layers.2.mlp.up_proj.weight", "pretrained_model.base_model.model.model.layers.2.mlp.down_proj.weight", "pretrained_model.base_model.model.model.layers.2.input_layernorm.weight", "pretrained_model.base_model.model.model.layers.2.post_attention_layernorm.weight",
  "pretrained_model.base_model.model.model.layers.3.self_attn.q_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight", "pretrained_model.base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight", "pretrained_model.base_model.model.model.layers.3.self_attn.k_proj.weight", "pretrained_model.base_model.model.model.layers.3.self_attn.v_proj.base_layer.weight", "pretrained_model.base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight", "pretraine

@hiyouga added the "pending" label on Jan 29, 2024
@chensimian

I ran into the same problem; the error is raised when saving at the end of training: RuntimeError: Error(s) in loading state_dict for AutoModelForCausalLMWithValueHead:


Liusifei commented Feb 9, 2024

Similar observation here: when I use the previously working script to resume, I get the following error for an SFT LoRA model.

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
Missing key(s) in state_dict: "base_model.model.model.embed_tokens.weight", "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.0.self_attn.k_proj.weight", "base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.0.self_attn.o_proj.weight", "base_model.model.model.layers.0.mlp.gate_proj.weight", "base_model.model.model.layers.0.mlp.up_proj.weight", "base_model.model.model.layers.0.mlp.down_proj.weight", "base_model.model.model.layers.0.input_layernorm.weight", "base_model.model.model.layers.0.post_attention_layernorm.weight", "base_model.model.model.layers.1.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.1.self_attn.k_proj.weight", "base_model.model.model.layers.1.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.1.self_attn.o_proj.weight", "base_model.model.model.layers.1.mlp.gate_proj.weight", "base_model.model.model.layers.1.mlp.up_proj.weight", "base_model.model.model.layers.1.mlp.down_proj.weight", "base_model.model.model.layers.1.input_layernorm.weight", "base_model.model.model.layers.1.post_attention_layernorm.weight", "base_model.model.model.layers.2.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.2.self_attn.k_proj.weight", "base_model.model.model.layers.2.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.2.self_attn.o_proj.weight", "base_model.model.model.layers.2.mlp.gate_proj.weight", "base_model.model.model.layers.2.mlp.up_proj.weight", "base_model.model.model.layers.2.mlp.down_proj.weight", "base_model.model.model.layers.2.input_layernorm.weight", "base_model.model.model.layers.2.post_attention_layernorm.weight", "base_model.model.model.layers.3.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.3.self_attn.k_proj.weight", "base_model.model.model.layers.3.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.3.self_attn.o_proj.weight", "base_model.model.model.layers.3.mlp.gate_proj.weight", "base_model.model.model.layers.3.mlp.up_proj.weight", "base_model.model.model.layers.3.mlp.down_proj.weight", "base_model.model.model.layers.3.input_layernorm.weight", "base_model.model.model.layers.3.post_attention_layernorm.weight", "base_model.model.model.layers.4.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.4.self_attn.k_proj.weight", "base_model.model.model.layers.4.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.4.self_attn.o_proj.weight", "base_model.model.model.layers.4.mlp.gate_proj.weight", "base_model.model.model.layers.4.mlp.up_proj.weight", "base_model.model.model.layers.4.mlp.down_proj.weight", "base_model.model.model.layers.4.input_layernorm.weight", "base_model.model.model.layers.4.post_attention_layernorm.weight", "base_model.model.model.layers.5.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.5.self_attn.k_proj.weight", "base_model.model.model.layers.5.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.5.self_attn.o_proj.weight", "base_model.model.model.layers.5.mlp.gate_proj.weight", "base_model.model.model.layers.5.mlp.up_proj.weight", "base_model.model.model.layers.5.mlp.down_proj.weight", "base_model.model.model.layers.5.input_layernorm.weight", "base_model.model.model.layers.5.post_attention_layernorm.weight", "base_model.model.model.layers.6.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.6.self_attn.k_proj.weight", 
"base_model.model.model.layers.6.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.6.self_attn.o_proj.weight", "base_model.model.model.layers.6.mlp.gate_proj.weight", "base_model.model.model.layers.6.mlp.up_proj.weight", "base_model.model.model.layers.6.mlp.down_proj.weight", "base_model.model.model.layers.6.input_layernorm.weight", "base_model.model.model.layers.6.post_attention_layernorm.weight", "base_model.model.model.layers.7.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.7.self_attn.k_proj.weight", "base_model.model.model.layers.7.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.7.self_attn.o_proj.weight", "base_model.model.model.layers.7.mlp.gate_proj.weight", "base_model.model.model.layers.7.mlp.up_proj.weight", "base_model.model.model.layers.7.mlp.down_proj.weight", "base_model.model.model.layers.7.input_layernorm.weight", "base_model.model.model.layers.7.post_attention_layernorm.weight", "base_model.model.model.layers.8.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.8.self_attn.k_proj.weight", "base_model.model.model.layers.8.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.8.self_attn.o_proj.weight", "base_model.model.model.layers.8.mlp.gate_proj.weight", "base_model.model.model.layers.8.mlp.up_proj.weight", "base_model.model.model.layers.8.mlp.down_proj.weight", "base_model.model.model.layers.8.input_layernorm.weight", "base_model.model.model.layers.8.post_attention_layernorm.weight", "base_model.model.model.layers.9.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.9.self_attn.k_proj.weight", "base_model.model.model.layers.9.self_attn.v_proj.base_layer.weight", "base_model.model.model.layers.9.self_attn.o_proj.weight", "base_model.model.model.layers.9.mlp.gate_proj.weight", "base_model.model.model.layers.9.mlp.up_proj.weight", "base_model.model.model.layers.9.mlp.down_proj.weight", "base_model.model.model.layers.9.input_layernorm.weight", "base_model.model.model.layers.9.post_attention_layernorm.weight", "base_model.model.model.layers.10.self_attn.q_proj.base_layer.weight", "base_model.model.model.layers.10.self_attn.k_proj.weight",
...

@stephen-nju (Contributor)

I hit the same problem when resuming training with DeepSpeed. A tentative analysis: 1. AutoModelForCausalLMWithValueHead is not a PeftModel, so the checkpoint can only be loaded in strict mode, yet the saved state_dict contains just two entries (see screenshot). 2. The DeepSpeed checkpoint also contains frozen_param_fragments (see screenshot).
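
For readers who only have the text above, here is a minimal sketch (paraphrased from newer transformers releases, not the verbatim library code) of the strictness decision that analysis refers to: a plain PEFT model is loaded non-strictly, but the value-head wrapper is not a PeftModel, so the strict path is taken and the missing base-model keys raise the RuntimeError.

# Sketch only: how newer transformers decides module-load strictness on resume.
from peft import PeftModel

def _is_peft_model(model):
    # AutoModelForCausalLMWithValueHead merely wraps the PeftModel,
    # so isinstance() is False and a strict load is requested.
    return isinstance(model, PeftModel)

# Simplified resume path inside Trainer._inner_training_loop:
#   deepspeed_load_checkpoint(
#       self.model_wrapped,
#       resume_from_checkpoint,
#       load_module_strict=not _is_peft_model(self.model),
#   )
# With strict loading, the saved state_dict (only two v_head entries, per the
# analysis above) cannot populate the whole module, hence "Missing key(s)".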

@stephen-nju (Contributor)

See #559: modifying the transformers source at site-packages/transformers/integrations/deepspeed.py (change shown in the screenshot) resolves the problem.
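
For reference when the screenshot is unavailable, the idea behind that change (a hedged sketch, not the exact patch from #559) is to pass load_module_strict=False when the DeepSpeed engine restores the module, so the partial state_dict no longer trips the strict check:

# Illustrative sketch of an edited deepspeed_load_checkpoint() in
# site-packages/transformers/integrations/deepspeed.py (not the verbatim file).
import glob

def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
    # Only attempt a resume if the directory actually holds a DeepSpeed checkpoint.
    deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
    if len(deepspeed_checkpoint_dirs) > 0:
        load_path, _ = deepspeed_engine.load_checkpoint(
            checkpoint_path,
            load_module_strict=False,  # the relevant change: skip the strict state_dict check
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        if load_path is None:
            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
    else:
        raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")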

@hiyouga added the "good first issue" label on Mar 7, 2024
@onenotell

Piggybacking on this thread with a question about this code:

if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
    can_resume_from_checkpoint = False
    if training_args.resume_from_checkpoint is not None:
        logger.warning("Cannot resume from checkpoint in current stage.")
        training_args.resume_from_checkpoint = None
else:
    can_resume_from_checkpoint = True

Could resume-from-checkpoint support be added for full-parameter RM training?
Also, supporting resume for PPO seems harder; will it be supported later?

Problems are very likely to occur during long training runs, so checkpoint resume would be a much-appreciated feature. Thanks.

@onenotell

> See #559: modifying the transformers source at site-packages/transformers/integrations/deepspeed.py (change shown in the screenshot) resolves the problem.

With Python 3.10 and transformers 4.28.2, the patch can be applied directly with a shell command when building the environment; tested and it works:

sed -i 's/self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)/self.model_wrapped, resume_from_checkpoint, load_module_strict=False/g' /usr/local/lib/python3.10/dist-packages/transformers/trainer.py
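
After that sed runs (assuming the installed trainer.py contains the matched line), the resume call should read roughly as follows:

# Resulting call in site-packages/transformers/trainer.py after the substitution.
deepspeed_load_checkpoint(
    self.model_wrapped, resume_from_checkpoint, load_module_strict=False
)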
