
Error when converting PiSSA adapter to normal LoRA in deepspeed stage 3 mode #5122

Closed
flymark2010 opened this issue Aug 8, 2024 · 4 comments
Labels
solved (This problem has been already solved)

Comments


flymark2010 commented Aug 8, 2024

Reminder

  • I have read the README and searched the existing issues.

System Info

  • llamafactory version: 0.8.4.dev0
  • Platform: Linux-5.4.0-100-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version: 2.3.0a0+ebedce2 (GPU)
  • Transformers version: 4.43.3
  • Datasets version: 2.18.0
  • Accelerate version: 0.33.0
  • PEFT version: 0.11.1
  • TRL version: 0.9.6
  • GPU type: NVIDIA A800-SXM4-80GB
  • DeepSpeed version: 0.14.4

Reproduction

  1. Use scripts/pissa_init.py to initialize the PiSSA adapter, which produces the residual model and the PiSSA adapter (a rough sketch of this step is included after these reproduction steps).
  2. Modify the PiSSA training config examples/extras/pissa/llama3_lora_sft.yaml as below:
### model
model_name_or_path: /workspace/project/LLaMA-Factory/models/qwen2_7b_lora_pissa_init_nter16_r128_all/
adapter_name_or_path: /workspace/project/LLaMA-Factory/models/qwen2_7b_lora_pissa_init_nter16_r128_all/pissa_init

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
pissa_init: false
pissa_convert: true
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
...
  3. Train with the command: llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
  4. Training proceeds normally, but an exception is raised when converting the PiSSA adapter to a normal LoRA adapter:
/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /workspace/project/LLaMA-Factory/models/qwen2_7b_lora_pissa_init_nter16_r128_all - will assume that the vocabulary was not modified.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/peft/peft_model.py:230: UserWarning: `convert_pissa_to_lora` only works for converting a PiSSA adapter to a LoRA adapter
  warnings.warn("`convert_pissa_to_lora` only works for converting a PiSSA adapter to a LoRA adapter")
Traceback (most recent call last):
  File "/workspace/project/LLaMA-Factory/src/llamafactory/launcher.py", line 23, in <module>
    launch()
  File "/workspace/project/LLaMA-Factory/src/llamafactory/launcher.py", line 19, in launch
    run_exp()
  File "/workspace/project/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/workspace/project/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 94, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2438, in _inner_training_loop
    self.control = self.callback_handler.on_train_end(args, self.state, self.control)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer_callback.py", line 463, in on_train_end
    return self.call_event("on_train_end", args, state, control)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer_callback.py", line 507, in call_event
    result = getattr(callback, event)(
  File "/workspace/project/LLaMA-Factory/src/llamafactory/train/callbacks.py", line 163, in on_train_end
    model.save_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 283, in save_pretrained
    output_state_dict = save_pissa_as_lora(
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 232, in save_pissa_as_lora
    self.load_adapter(
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 988, in load_adapter
    load_result = set_peft_model_state_dict(
  File "/usr/local/lib/python3.10/dist-packages/peft/utils/save_and_load.py", line 353, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
        size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_A.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([128, 3584]).
        size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_B.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([3584, 128]).
        size mismatch for base_model.model.model.layers.0.self_attn.k_proj.lora_A.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([128, 3584]).
        size mismatch for base_model.model.model.layers.0.self_attn.k_proj.lora_B.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 128]).
        size mismatch for base_model.model.model.layers.0.self_attn.v_proj.lora_A.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([128, 3584]).
        size mismatch for base_model.model.model.layers.0.self_attn.v_proj.lora_B.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([512, 128]).
        size mismatch for base_model.model.model.layers.0.self_attn.o_proj.lora_A.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([128, 3584]).
        size mismatch for base_model.model.model.layers.0.self_attn.o_proj.lora_B.pissa_init.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([3584, 128]).
...
  5. If the DeepSpeed config is changed to stage 0, no error occurs.
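
For reference, a rough sketch of what step 1 amounts to, assuming PEFT's init_lora_weights="pissa_niter_16" option; the actual scripts/pissa_init.py may differ in detail, and the model name and output path below are illustrative:

# Rough sketch of the PiSSA initialization in step 1; scripts/pissa_init.py may differ.
# The model name and OUTPUT_DIR are illustrative, not taken from the repository.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

OUTPUT_DIR = "models/qwen2_7b_lora_pissa_init_nter16_r128_all"

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")

lora_config = LoraConfig(
    r=128,
    lora_alpha=128,
    target_modules="all-linear",         # corresponds to lora_target: all
    init_lora_weights="pissa_niter_16",  # PiSSA init via fast SVD with 16 iterations
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)

# Save the PiSSA adapter under "pissa_init" (used as adapter_name_or_path above) ...
peft_model.save_pretrained(f"{OUTPUT_DIR}/pissa_init")
# ... and the residual base model (used as model_name_or_path above).
peft_model.unload().save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)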

Expected behavior

The PiSSA adapter can be converted to a normal LoRA adapter in DeepSpeed stage 3 mode.

Others

No response

github-actions bot added the pending (This problem is yet to be addressed) label on Aug 8, 2024
hiyouga (Owner) commented Aug 9, 2024

Try setting pissa_convert: false.

hiyouga (Owner) commented Aug 9, 2024

@BenjaminBossan There may be a bug in save_pissa_as_lora under DeepSpeed ZeRO-3; has it been fixed in the latest version?

BenjaminBossan commented:

Ah yes, quite possibly we're missing gather_param_ctx when we collect the mutated state dict. Could you quickly try adding it to check whether that fixes the issue?
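
For context, under ZeRO-3 each parameter is partitioned across ranks and shows up locally as a zero-size placeholder until it is gathered, which is where the torch.Size([0]) mismatches above come from. A minimal sketch of the kind of gathering meant here, assuming DeepSpeed's deepspeed.zero.GatheredParameters context (the PEFT-internal helper itself is not reproduced here):

import deepspeed

def gather_lora_state_dict(model):
    # Collect the full (un-partitioned) LoRA weights under DeepSpeed ZeRO-3.
    lora_params = [p for n, p in model.named_parameters() if "lora_" in n]
    with deepspeed.zero.GatheredParameters(lora_params):
        # Inside the context the full tensors are materialized, so they can be
        # cloned into a state dict usable for the PiSSA-to-LoRA conversion.
        return {
            n: p.detach().cpu().clone()
            for n, p in model.named_parameters()
            if "lora_" in n
        }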

Btw nice work with PissaConvertCallback. Just a small note regarding this line:

pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir

We have deprecated the argument name convert_pissa_to_lora in favor of path_initial_model_for_weight_conversion. First, it is a more sensible name; second, it allows us to reuse the parameter for similar methods that are not PiSSA, e.g. OLoRA.
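
For reference, a sketch of the same call with the new argument name (peft>=0.12.0), reusing the identifiers from the quoted line:

model.save_pretrained(
    pissa_convert_dir,
    safe_serialization=args.save_safetensors,
    path_initial_model_for_weight_conversion=pissa_init_dir,
)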

hiyouga (Owner) commented Aug 9, 2024

@BenjaminBossan
Thanks for your prompt reply. Following your advice, we find it difficult to resolve the problem of converting PiSSA weights under DeepSpeed ZeRO-3: we only call the save function on the main process, while the gather context must be entered on all ranks. We will recommend that users perform the PiSSA conversion without DeepSpeed ZeRO-3.
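
To illustrate the constraint, a rough sketch (not actual LLaMA-Factory or PEFT code) of why the gather cannot simply be wrapped around a main-process-only save under ZeRO-3, assuming deepspeed.zero.GatheredParameters and an initialized torch.distributed process group:

import deepspeed
import torch.distributed as dist

def save_adapter_zero3(model, output_dir):
    lora_params = [p for n, p in model.named_parameters() if "lora_" in n]
    # GatheredParameters is a collective: every rank must enter it, otherwise the
    # ranks that did enter block forever waiting for the missing ones.
    with deepspeed.zero.GatheredParameters(lora_params):
        if dist.get_rank() == 0:
            # Only the main process writes the converted adapter to disk.
            model.save_pretrained(output_dir)

# A callback that runs save_pretrained only on the main process never reaches the
# collective gather on the other ranks, which is why the conversion fails there.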

Regarding the callback, we have noticed the change; the current argument is kept for backward compatibility with peft<0.12.0, as noted in the comment :)

hiyouga added the solved label and removed the pending label on Aug 9, 2024
hiyouga closed this as completed on Aug 9, 2024