-
Notifications
You must be signed in to change notification settings - Fork 6
/
merge_peft_adapter.py
84 lines (70 loc) · 3.51 KB
/
merge_peft_adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!python
# -*- coding: utf-8 -*-
# @author: Kun
'''
Author: Kun
Date: 2023-04-19 00:13:26
LastEditTime: 2023-04-19 00:16:00
LastEditors: Kun
Description:
FilePath: /Alpaca-RLHF-PyTorch/merge_peft_adapter.py
'''
from dataclasses import dataclass, field
from typing import Optional
import peft
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from transformers import LlamaForCausalLM, LlamaTokenizer
@dataclass
class ScriptArguments:
    """Command-line options for merging a PEFT (LoRA) adapter into its base causal LM.

    The base model is not an option here: the script reads it from the
    adapter's saved PEFT config (``base_model_name_or_path``).
    """

    # NOTE: gpt2 models use Conv1D instead of Linear layers, which are not yet
    # supported in 8-bit mode; models such as gpt-neo* are more suitable.

    # Location of the trained adapter; also used to derive the default output
    # directory when `output_name` is not given.
    model_name: Optional[str] = field(
        default="./lora-alpaca",
        metadata={"help": "name or path of the PEFT (LoRA) adapter to merge"},
    )
    # Where the merged model is written; None means "derive it from model_name".
    output_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "directory to save the merged model to "
            "(default: '<model_name>-adapter-merged')"
        },
    )
# Turn the command-line flags into a ScriptArguments instance. The parser hands
# back one object per dataclass it was built with, so the first entry is ours.
parser = HfArgumentParser(ScriptArguments)
parsed_dataclasses = parser.parse_args_into_dataclasses()
script_args = parsed_dataclasses[0]
print("script_args: ", script_args)

# The adapter location given on the command line serves as the PEFT model id;
# its saved configuration is loaded from that same location.
peft_model_id = script_args.model_name
peft_config = PeftConfig.from_pretrained(peft_model_id)
print("peft_config: ", peft_config)
# Instantiate the base model named in the adapter config, with float16 weights.
base_model_path = peft_config.base_model_name_or_path
base_model_kwargs = {
    "return_dict": True,
    "torch_dtype": torch.float16,
    # Base models that ship their own modelling code (e.g. THUDM/chatglm-6b)
    # are rejected with a ValueError unless `trust_remote_code=True` is passed.
    # Only enable it after reading the code in that repository.
    # "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(base_model_path, **base_model_kwargs)

# LlamaTokenizer is used directly instead of AutoTokenizer: the latter raises
# "ValueError: Tokenizer class LLaMATokenizer does not exist or is not currently
# imported." because the tokenizer config on the hub names `LLaMATokenizer`
# while the class in the library is `LlamaTokenizer` (the config files were
# likely created before the final PR was merged).
# Reference: https://github.com/huggingface/transformers/issues/22222
tokenizer = LlamaTokenizer.from_pretrained(base_model_path)

# Attach the LoRA adapter weights to the base model and switch to eval mode.
model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()
# Fold the LoRA adapter into the base model and drop the PEFT wrappers, leaving
# `model` bound to the plain transformers model.
#
# Newer peft releases expose this as the public `merge_and_unload()`; the
# private helpers used by the manual loop below (`_get_submodules` as a method)
# are not available there. Prefer the public API when the installed peft has
# it, and keep the original manual replacement only as the peft==0.2.0 path.
merge_and_unload = getattr(model.base_model, "merge_and_unload", None)
if merge_and_unload is not None:
    # Returns the underlying base model with the adapter merged in.
    model = merge_and_unload()
else:
    # peft==0.2.0: walk every non-LoRA module name and swap each LoRA-wrapped
    # Linear for a plain torch.nn.Linear of the same shape.
    key_list = [key for key, _ in model.base_model.model.named_modules() if "lora" not in key]
    print("key_list: ", key_list)
    for key in key_list:
        print("key: ", key)
        parent, target, target_name = model.base_model._get_submodules(key)
        if isinstance(target, peft.tuners.lora.Linear):
            bias = target.bias is not None
            new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
            # `_replace_module` installs `new_module` under `parent` in place of `target`.
            model.base_model._replace_module(parent, target_name, new_module, target)
    # Unwrap: PeftModel -> LoRA wrapper -> underlying transformers model.
    model = model.base_model.model
# Decide where to write the merged checkpoint. Without an explicit
# --output_name the directory is derived from the adapter path; trailing path
# separators are stripped first so that "./lora-alpaca/" does not resolve to a
# "-adapter-merged" folder *inside* the adapter directory.
if script_args.output_name is None:
    adapter_path = script_args.model_name.rstrip("/\\")
    output_name = f"{adapter_path}-adapter-merged"
else:
    output_name = script_args.output_name

model.save_pretrained(output_name)
# Save the tokenizer next to the weights so the output directory can be loaded
# on its own; previously the tokenizer was loaded but never written out.
tokenizer.save_pretrained(output_name)
# To publish instead of (or in addition to) saving locally:
# model.push_to_hub(f"{script_args.model_name}-adapter-merged", use_temp_dir=False)