-
Notifications
You must be signed in to change notification settings - Fork 5
/
peft_tuners_lora.py
executable file
·191 lines (167 loc) · 8.06 KB
/
peft_tuners_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import math
import re
import torch
import warnings
from peft.tuners import lora
from peft.tuners.lora import Linear, LoraLayer
from peft.utils import _get_submodules, PeftType
from torch import nn
from transformers.pytorch_utils import Conv1D
from model import QuantLinear
class LinearLowbitLt(QuantLinear, LoraLayer):
    """LoRA adapter layered on top of a low-bit ``QuantLinear`` projection.

    ``QuantLinear.forward`` computes the frozen base projection; the active
    adapter's update ``lora_B(lora_A(lora_dropout(x))) * scaling`` is added to
    its output.
    """

    def __init__(
        self,
        adapter_name,
        in_features,
        out_features,
        groupsize: int = -1,
        bits=4,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        """Build the quantized base layer and register the first LoRA adapter.

        Args:
            adapter_name: Name under which the LoRA adapter is registered and
                made active.
            in_features: Input dimension, forwarded to ``QuantLinear`` and
                ``LoraLayer``.
            out_features: Output dimension, forwarded to ``QuantLinear`` and
                ``LoraLayer``.
            groupsize: Forwarded to ``QuantLinear``.
            bits: Forwarded to ``QuantLinear``.
            r: LoRA rank, forwarded to ``update_layer``.
            lora_alpha: LoRA alpha, forwarded to ``update_layer``.
            lora_dropout: LoRA dropout probability, forwarded to ``update_layer``.
            **kwargs: Only ``init_lora_weights`` (default ``True``) is consumed.
                Other keys passed by ``LoraModel._find_and_replace`` (such as
                ``bias`` and ``fan_in_fan_out``) are accepted and ignored.
        """
        QuantLinear.__init__(
            self,
            in_features,
            out_features,
            groupsize,
            bits,
        )
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)

        # Freeze the pre-trained quantized tensors; only LoRA parameters train.
        # ``None`` entries are skipped so a QuantLinear without a bias does not
        # crash here. NOTE(review): LoraModel._replace_module later swaps these
        # tensors for the original layer's ones; freezing those is presumably
        # left to peft's trainable-parameter marking — confirm.
        for tensor_name in (
            "qweight",
            "qscales",
            "qscales_scales",
            "qscales_zeros",
            "qzeros",
            "g_idx",
            "bias",
        ):
            tensor = getattr(self, tensor_name)
            if tensor is not None:
                tensor.requires_grad = False

        init_lora_weights = kwargs.pop("init_lora_weights", True)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def forward(self, x: torch.Tensor):
        """Return the quantized projection of ``x`` plus the active LoRA update.

        The LoRA branch is skipped when adapters are disabled, when the active
        adapter is not registered on this layer, or when its rank is 0.
        """
        result = super().forward(x)

        adapter = self.active_adapter
        if self.disable_adapters or adapter not in self.lora_A.keys():
            return result
        if self.r[adapter] > 0:
            lora_A = self.lora_A[adapter]
            lora_B = self.lora_B[adapter]
            dropout = self.lora_dropout[adapter]
            scaling = self.scaling[adapter]
            if not torch.is_autocast_enabled():
                # Without autocast, feed the adapter float32 input and cast its
                # output back to the base result's dtype.
                expected_dtype = result.dtype
                if x.dtype != torch.float32:
                    x = x.float()
                output = lora_B(lora_A(dropout(x))).to(expected_dtype) * scaling
            else:
                output = lora_B(lora_A(dropout(x))) * scaling
            # Out-of-place add: modifying the quantized kernel's output in place
            # can break backprop when that tensor is a view or is saved for
            # backward. Casting to result.dtype keeps the dtype that in-place
            # ``+=`` would have produced.
            result = result + output.to(result.dtype)
        return result

    @property
    def weight(self):
        """Return a minimal stand-in for a dense ``weight`` exposing only ``device``.

        This layer keeps its weights in ``qweight`` and related tensors rather
        than in a ``weight`` attribute, so the stand-in reports ``qweight``'s
        device. NOTE(review): presumably needed because peft LoRA helpers read
        ``self.weight.device`` — confirm.
        """

        class WeightDeviceClass:
            device = self.qweight.device

        return WeightDeviceClass()
class LoraModel(lora.LoraModel):
    """peft ``LoraModel`` variant that can also wrap quantized ``QuantLinear`` layers.

    Targeted modules are handled as follows:

    - ``QuantLinear`` modules are replaced with ``LinearLowbitLt``.
    - ``torch.nn.Linear`` and ``Conv1D`` modules are replaced with peft's LoRA
      ``Linear``.
    - Modules that are already ``LoraLayer`` instances get the new adapter added
      through ``update_layer``.
    """

    def _find_and_replace(self, adapter_name):
        """Inject the LoRA adapter ``adapter_name`` into every targeted module.

        Module names in ``self.model`` are matched against
        ``lora_config.target_modules``. A string is treated as a regex and must
        match the whole name (``re.fullmatch``). Any other value is treated as a
        collection of name suffixes.

        Args:
            adapter_name: Key into ``self.peft_config`` selecting the LoRA config.

        Raises:
            ValueError: If a matched module has an unsupported type, or if no
                module name matches ``target_modules``.
        """
        lora_config = self.peft_config[adapter_name]
        kwargs = {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        is_target_modules_in_base_model = False
        # Snapshot the names first, because the loop swaps modules inside self.model.
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if not self._matches_target_module(lora_config.target_modules, key):
                continue
            is_target_modules_in_base_model = True
            parent, target, target_name = _get_submodules(self.model, key)
            if isinstance(target, LoraLayer):
                # Already LoRA-wrapped (presumably by an earlier adapter): add
                # this adapter to the existing layer.
                target.update_layer(
                    adapter_name,
                    lora_config.r,
                    lora_config.lora_alpha,
                    lora_config.lora_dropout,
                    lora_config.init_lora_weights,
                )
            else:
                new_module = self._build_lora_module(adapter_name, target, lora_config, kwargs)
                self._replace_module(parent, target_name, new_module, target)

        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    @staticmethod
    def _matches_target_module(target_modules, key):
        """Return True if module name ``key`` is selected by ``target_modules``."""
        if isinstance(target_modules, str):
            return re.fullmatch(target_modules, key) is not None
        return any(key.endswith(target_key) for target_key in target_modules)

    @staticmethod
    def _build_lora_module(adapter_name, target, lora_config, kwargs):
        """Create the LoRA replacement for the non-LoRA module ``target``.

        ``kwargs`` is the shared kwargs dict built in ``_find_and_replace``. Any
        ``fan_in_fan_out`` correction is written back to both ``kwargs`` and
        ``lora_config``, so it also applies to later targets.

        Raises:
            ValueError: If ``target`` is not a ``QuantLinear``,
                ``torch.nn.Linear`` or ``Conv1D``.
        """
        # getattr: a matched module with no ``bias`` attribute should reach the
        # "not supported" ValueError below rather than raise AttributeError here.
        bias = getattr(target, "bias", None) is not None
        if isinstance(target, QuantLinear):
            return LinearLowbitLt(
                adapter_name,
                target.in_features,
                target.out_features,
                target.groupsize,
                target.bits,
                bias=bias,
                **kwargs,
            )
        if isinstance(target, torch.nn.Linear):
            in_features, out_features = target.in_features, target.out_features
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        elif isinstance(target, Conv1D):
            in_features, out_features = (
                target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
            )
            if not kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                    "Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
        else:
            raise ValueError(
                f"Target module {target} is not supported. "
                f"Currently, only `torch.nn.Linear`, `Conv1D` and `QuantLinear` are supported."
            )
        return Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` as ``parent_module.<child_name>``, reusing ``old_module``'s tensors.

        For a ``QuantLinear`` → ``LinearLowbitLt`` swap, the quantized tensors
        and bias are copied from the old layer. Otherwise ``weight`` is copied,
        and ``bias`` is copied only when it is not ``None``. A non-None
        ``state`` attribute is carried over in both cases. The new module is
        then moved to the old layer's device.
        """
        setattr(parent_module, child_name, new_module)
        if isinstance(old_module, QuantLinear) and isinstance(new_module, LinearLowbitLt):
            # Replace the placeholders created in LinearLowbitLt.__init__ with
            # the already-loaded tensors of the original layer.
            for tensor_name in (
                "qweight",
                "qscales",
                "qscales_scales",
                "qscales_zeros",
                "qzeros",
                "g_idx",
                "bias",
            ):
                setattr(new_module, tensor_name, getattr(old_module, tensor_name))
            device = old_module.qweight.device
        else:
            new_module.weight = old_module.weight
            if old_module.bias is not None:
                new_module.bias = old_module.bias
            device = old_module.weight.device

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state

        # Move the whole module, not only its ``lora_`` submodules. Any tensor
        # that was not copied above (for instance other QuantLinear buffers
        # created in __init__ — TODO confirm they exist) then also ends up on
        # the layer's device.
        new_module.to(device)
def replace_peft_model_with_lora_model():
    """Register this module's ``LoraModel`` as peft's implementation for LoRA.

    The entry for ``PeftType.LORA`` in
    ``peft.peft_model.PEFT_TYPE_TO_MODEL_MAPPING`` is overwritten, so that
    peft models built with a LoRA config afterwards use the quantization-aware
    ``LoraModel`` defined here. This patches peft's module-level state
    process-wide.
    """
    from peft import peft_model as peft_model_module

    model_registry = peft_model_module.PEFT_TYPE_TO_MODEL_MAPPING
    model_registry[PeftType.LORA] = LoraModel