-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
bnb.py
319 lines (274 loc) · 13.7 KB
/
bnb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional
import bitsandbytes as bnb
import torch
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.utils.other import transpose
from .layer import LoraLayer
if is_bnb_available():

    class Linear8bitLt(torch.nn.Module, LoraLayer):
        """LoRA layer on top of a base layer whose weight is handled as bitsandbytes int8 parameters.

        `forward` always calls the base layer; unless the adapters are disabled or already merged, it then adds the
        scaled low-rank output of every active adapter. `merge` / `unmerge` fold the adapter deltas into, or remove
        them from, the base weight by dequantizing it, applying the delta and wrapping the result in
        `bnb.nn.Int8Params` again.
        """

        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            **kwargs,
        ) -> None:
            """
            Wrap `base_layer` and register the first adapter.

            Args:
                base_layer (`torch.nn.Module`):
                    The layer to wrap; it is passed to `LoraLayer.__init__`.
                adapter_name (`str`):
                    Name under which the adapter is registered.
                r, lora_alpha, lora_dropout, init_lora_weights:
                    Forwarded unchanged to `self.update_layer`.
                **kwargs:
                    Accepted for signature compatibility; not used by this constructor.
            """
            super().__init__()
            LoraLayer.__init__(self, base_layer)

            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        @staticmethod
        def _dequantize_base_weight(weight, state):
            """
            Return the dequantized base weight computed from the int8 `weight` and its matmul `state`.

            As a side effect, `state.SCB` is filled from `weight.SCB` and `state.CxB` / `state.SB` are computed from
            `weight.data` when they are still `None`.
            """
            if state.SCB is None:
                state.SCB = weight.SCB

            # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
            # dequantization directly
            identity = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
            identity, _, identity_scale, _, _ = bnb.functional.double_quant(identity)
            identity, identity_layout = bnb.functional.transform(identity, "col32")
            if state.CxB is None:
                state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
            out32, out32_layout = bnb.functional.igemmlt(identity, state.CxB, identity_layout, state.SB)
            return bnb.functional.mm_dequant(out32, out32_layout, identity_scale, state.SCB, bias=None).t()

        def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
            """
            Merge the active adapter weights into the base weights

            Adapters are merged one after another; each successfully merged adapter is appended to
            `self.merged_adapters` before the next one is processed.

            Args:
                safe_merge (`bool`, *optional*):
                    If True, the merged weights are checked for non-finite values before they replace the base
                    weights, and a `ValueError` is raised if any are found. Defaults to `False`.
                adapter_names (`List[str]`, *optional*):
                    The list of adapter names that should be merged. If None, all active adapters will be merged.
                    Defaults to `None`.

            Raises:
                ValueError: If `safe_merge` is True and the merged weights of an adapter are not all finite.
            """
            if adapter_names is None:
                adapter_names = self.active_adapters

            if self.merged:
                # Name the adapters that are actually about to be merged; with an explicit `adapter_names` these
                # can differ from the active adapters.
                warnings.warn(
                    f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                    f"You are now additionally merging {','.join(adapter_names)}."
                )

            for active_adapter in adapter_names:
                if active_adapter not in self.lora_A:
                    continue
                warnings.warn(
                    "Merge lora module to 8-bit linear may get different generations due to rounding errors."
                )
                lora_data = self.get_delta_weight(active_adapter)

                base_layer = self.get_base_layer()
                weight = base_layer.weight
                state = base_layer.state
                dequantized = self._dequantize_base_weight(weight, state)

                w_data = dequantized.to(lora_data.dtype).to(lora_data.device) + lora_data
                if safe_merge and not torch.isfinite(w_data).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                # NOTE(review): the CPU round trip presumably makes bitsandbytes re-quantize the data when the
                # parameter is moved back to the device -- confirm against the bitsandbytes version in use.
                base_layer.weight = bnb.nn.Int8Params(
                    w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
                ).to(weight.device)
                # The cached state was derived from the old weight.
                state.reset_grads()
                self.merged_adapters.append(active_adapter)

        def unmerge(self) -> None:
            """
            This method unmerges all merged adapter layers from the base weights.

            Adapters are popped from `self.merged_adapters` (last merged first) and their delta is subtracted from
            the dequantized base weight. A warning is emitted and nothing is done if no adapter is merged.
            """
            if not self.merged:
                warnings.warn("Already unmerged. Nothing to do.")
                return

            while self.merged_adapters:
                active_adapter = self.merged_adapters.pop()
                if active_adapter not in self.lora_A:
                    continue
                warnings.warn(
                    "Unmerge lora module to 8-bit linear may get different generations due to rounding errors."
                )
                lora_data = self.get_delta_weight(active_adapter)

                base_layer = self.get_base_layer()
                weight = base_layer.weight
                state = base_layer.state
                dequantized = self._dequantize_base_weight(weight, state)

                w_data = dequantized.to(lora_data.dtype).to(lora_data.device) - lora_data
                base_layer.weight = bnb.nn.Int8Params(
                    w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
                ).to(weight.device)
                # The cached state was derived from the old weight.
                state.reset_grads()

        def get_delta_weight(self, adapter):
            """
            Return the weight delta of `adapter`.

            The delta is `lora_B.weight @ lora_A.weight`, passed through `transpose(..., False)` and multiplied by
            the adapter's entry in `self.scaling`.
            """
            low_rank_product = self.lora_B[adapter].weight @ self.lora_A[adapter].weight
            return transpose(low_rank_product, False) * self.scaling[adapter]

        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            """
            Apply the base layer and, if applicable, add the output of the active adapters.

            If the adapters are disabled while still merged, they are unmerged first. When the adapters are disabled
            or merged, the plain base layer output is returned. Otherwise every active adapter that has LoRA weights
            contributes `lora_B(lora_A(dropout(x))) * scaling`, added in place to the base output.
            """
            if self.disable_adapters and self.merged:
                # Disabled adapters must not leave a merged delta in the base weight.
                self.unmerge()

            result = self.base_layer(x, *args, **kwargs)
            if self.disable_adapters or self.merged:
                return result

            # Without autocast the dtypes are aligned by hand: the input is cast to the LoRA weight dtype and the
            # LoRA output is cast back to the dtype of the base output.
            requires_conversion = not torch.is_autocast_enabled()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A:
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                if requires_conversion:
                    expected_dtype = result.dtype
                    compute_dtype = lora_A.weight.dtype
                    if x.dtype != compute_dtype:
                        x = x.to(compute_dtype)

                output = lora_B(lora_A(dropout(x)))
                if requires_conversion:
                    output = output.to(expected_dtype)
                result += output * scaling

            return result

        def __repr__(self) -> str:
            """Return the parent representation prefixed with `lora.`."""
            return "lora." + super().__repr__()
if is_bnb_4bit_available():

    class Linear4bit(torch.nn.Module, LoraLayer):
        """LoRA layer on top of a base layer whose weight is handled as bitsandbytes 4-bit parameters.

        `forward` always calls the base layer; unless the adapters are disabled or already merged, it then adds the
        scaled low-rank output of every active adapter to a clone of the base output. `merge` / `unmerge` fold the
        adapter deltas into, or remove them from, the base weight by dequantizing it with
        `bnb.functional.dequantize_4bit`, applying the delta and wrapping the result in `bnb.nn.Params4bit` again.
        """

        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            **kwargs,
        ) -> None:
            """
            Wrap `base_layer` and register the first adapter.

            Args:
                base_layer (`torch.nn.Module`):
                    The layer to wrap; it is passed to `LoraLayer.__init__`.
                adapter_name (`str`):
                    Name under which the adapter is registered.
                r, lora_alpha, lora_dropout, init_lora_weights:
                    Forwarded unchanged to `self.update_layer`.
                **kwargs:
                    Accepted for signature compatibility; not used by this constructor.
            """
            super().__init__()
            LoraLayer.__init__(self, base_layer)

            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
            """
            Merge the active adapter weights into the base weights

            Adapters are merged one after another; each successfully merged adapter is appended to
            `self.merged_adapters` before the next one is processed.

            Args:
                safe_merge (`bool`, *optional*):
                    If True, the merged weights are checked for non-finite values before they replace the base
                    weights, and a `ValueError` is raised if any are found. Defaults to `False`.
                adapter_names (`List[str]`, *optional*):
                    The list of adapter names that should be merged. If None, all active adapters will be merged.
                    Defaults to `None`.

            Raises:
                ValueError: If `safe_merge` is True and the merged weights of an adapter are not all finite.
            """
            if adapter_names is None:
                adapter_names = self.active_adapters

            if self.merged:
                # Name the adapters that are actually about to be merged; with an explicit `adapter_names` these
                # can differ from the active adapters.
                warnings.warn(
                    f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                    f"You are now additionally merging {','.join(adapter_names)}."
                )

            for active_adapter in adapter_names:
                if active_adapter not in self.lora_A:
                    continue
                warnings.warn(
                    "Merge lora module to 4-bit linear may get different generations due to rounding errors."
                )
                # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
                base_layer = self.get_base_layer()
                weight = base_layer.weight
                # The attributes stored on the current parameter are reused as keyword arguments for its
                # replacement.
                kwargs = weight.__dict__
                lora_data = self.get_delta_weight(active_adapter)

                w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + lora_data
                if safe_merge and not torch.isfinite(w_data).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                # NOTE(review): the CPU round trip presumably makes bitsandbytes re-quantize the data when the
                # parameter is moved back to the device -- confirm against the bitsandbytes version in use.
                base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                    weight.device
                )
                self.merged_adapters.append(active_adapter)

        def unmerge(self) -> None:
            """
            This method unmerges all merged adapter layers from the base weights.

            Adapters are popped from `self.merged_adapters` (last merged first) and their delta is subtracted from
            the dequantized base weight. A warning is emitted and nothing is done if no adapter is merged.
            """
            if not self.merged:
                warnings.warn("Already unmerged. Nothing to do.")
                return

            while self.merged_adapters:
                active_adapter = self.merged_adapters.pop()
                if active_adapter not in self.lora_A:
                    continue
                warnings.warn(
                    "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
                )
                base_layer = self.get_base_layer()
                weight = base_layer.weight
                # The attributes stored on the current parameter are reused as keyword arguments for its
                # replacement.
                kwargs = weight.__dict__
                lora_data = self.get_delta_weight(active_adapter)

                w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
                base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
                    weight.device
                )

        def get_delta_weight(self, adapter):
            """
            Return the weight delta of `adapter`.

            The delta is `lora_B.weight @ lora_A.weight`, passed through `transpose(..., False)` and multiplied by
            the adapter's entry in `self.scaling`.
            """
            low_rank_product = self.lora_B[adapter].weight @ self.lora_A[adapter].weight
            return transpose(low_rank_product, False) * self.scaling[adapter]

        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            """
            Apply the base layer and, if applicable, add the output of the active adapters.

            If the adapters are disabled while still merged, they are unmerged first. When the adapters are disabled
            or merged, the plain base layer output is returned. Otherwise every active adapter that has LoRA weights
            contributes `lora_B(lora_A(dropout(x))) * scaling`, added to a clone of the base output.
            """
            if self.disable_adapters and self.merged:
                # Disabled adapters must not leave a merged delta in the base weight.
                self.unmerge()

            result = self.base_layer(x, *args, **kwargs)
            if self.disable_adapters or self.merged:
                return result

            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            # Without autocast the dtypes are aligned by hand: the input is cast to the LoRA weight dtype and the
            # LoRA output is cast back to the dtype of the base output.
            requires_conversion = not torch.is_autocast_enabled()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A:
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                output = lora_B(lora_A(dropout(x)))
                if requires_conversion:
                    output = output.to(expected_dtype)
                result += output * scaling

            return result

        def __repr__(self) -> str:
            """Return the parent representation prefixed with `lora.`."""
            return "lora." + super().__repr__()