ENH: Refactor LoRA bnb layers for faster initialization #994

Merged
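The core of this change: the LoRA bnb layers no longer subclass `bnb.nn.Linear8bitLt` / `bnb.nn.Linear4bit`. They subclass `torch.nn.Module`, take the existing quantized module as `base_layer`, and delegate the forward pass to it, so creating the LoRA layer no longer rebuilds (and re-quantizes) the base weights. A minimal sketch of the wrapping pattern; the class name and the LoRA bookkeeping below are simplified stand-ins, not the actual PEFT implementation:

```python
import torch


class WrappedLoraLinear(torch.nn.Module):
    """Simplified stand-in for the refactored Linear8bitLt / Linear4bit LoRA layers."""

    def __init__(self, adapter_name: str, base_layer: torch.nn.Module, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        # Keep a reference to the existing (quantized) linear layer; its weights are left untouched.
        self.base_layer = base_layer
        self.adapter_name = adapter_name
        self.lora_A = torch.nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = torch.nn.Linear(r, base_layer.out_features, bias=False)
        self.scaling = lora_alpha / r
        torch.nn.init.zeros_(self.lora_B.weight)  # LoRA starts out as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the wrapped layer instead of relying on an inherited bnb forward.
        result = self.base_layer(x)
        return result + self.lora_B(self.lora_A(x)) * self.scaling
```

With this pattern, instantiating the adapter only allocates the two small LoRA matrices; the expensive bitsandbytes layer construction is never repeated.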
136 changes: 63 additions & 73 deletions src/peft/tuners/lora/bnb.py
@@ -26,32 +26,21 @@

if is_bnb_available():

class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
class Linear8bitLt(torch.nn.Module, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
) -> None:
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
super().__init__()
LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
self.base_layer = base_layer

# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.set_adapter(adapter_name)
@@ -71,6 +60,7 @@ def merge(self, safe_merge: bool = False):
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -79,37 +69,38 @@ def merge(self, safe_merge: bool = False):
)
lora_data = self.get_delta_weight(active_adapter)

if self.state.SCB is None:
self.state.SCB = self.weight.SCB
weight = self.base_layer.weight
state = self.base_layer.state
if state.SCB is None:
state.SCB = weight.SCB

# Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
# dequantization directly
im = torch.eye(self.weight.data.shape[-1]).contiguous().half().to(self.weight.device)
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
im, Sim = bnb.functional.transform(im, "col32")
if state.CxB is None:
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()

if self.state.CxB is None:
self.state.CxB, self.state.SB = bnb.functional.transform(
self.weight.data, to_order=self.state.formatB
)
out32, Sout32 = bnb.functional.igemmlt(im, self.state.CxB, Sim, self.state.SB)
output = bnb.functional.mm_dequant(out32, Sout32, SCim, self.state.SCB, bias=None).t()
w_data = output.to(lora_data.dtype).to(lora_data.device) + lora_data

if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

self.weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.weight.has_fp16_weights
).to(self.weight.device)
self.state.reset_grads()
self.base_layer.weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
state.reset_grads()
self.merged_adapters.append(active_adapter)

def unmerge(self):
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter not in self.lora_A.keys():
@@ -119,23 +110,24 @@ def unmerge(self):
)
lora_data = self.get_delta_weight(active_adapter)

if self.state.SCB is None:
self.state.SCB = self.weight.SCB
im = torch.eye(self.weight.data.shape[-1]).contiguous().half().to(self.weight.device)
weight = self.base_layer.weight
state = self.base_layer.state
if state.SCB is None:
state.SCB = weight.SCB
im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
im, Sim = bnb.functional.transform(im, "col32")

if self.state.CxB is None:
self.state.CxB, self.state.SB = bnb.functional.transform(
self.weight.data, to_order=self.state.formatB
)
out32, Sout32 = bnb.functional.igemmlt(im, self.state.CxB, Sim, self.state.SB)
output = bnb.functional.mm_dequant(out32, Sout32, SCim, self.state.SCB, bias=None).t()
if state.CxB is None:
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
output = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()

w_data = output.to(lora_data.dtype).to(lora_data.device) - lora_data
self.weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.weight.has_fp16_weights
).to(self.weight.device)
self.state.reset_grads()
self.base_layer.weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
state.reset_grads()

def get_delta_weight(self, adapter):
return (
@@ -146,15 +138,15 @@ def get_delta_weight(self, adapter):
* self.scaling[adapter]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = super().forward(x)
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = super().forward(x)
result = self.base_layer(x, *args, **kwargs)
else:
result = super().forward(x)
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
@@ -180,31 +172,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

if is_bnb_4bit_available():

class Linear4bit(bnb.nn.Linear4bit, LoraLayer):
class Linear4bit(torch.nn.Module, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
) -> None:
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
LoraLayer.__init__(self, in_features=in_features, out_features=out_features)

# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
super().__init__()
LoraLayer.__init__(self, in_features=base_layer.in_features, out_features=base_layer.out_features)
self.base_layer = base_layer

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
@@ -225,39 +206,48 @@ def merge(self, safe_merge: bool = False):
f"Already following adapters were merged {','.join(self.merged_adapters)}. "
f"You are now additionally merging {','.join(self.active_adapters)}."
)

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
warnings.warn(
"Merge lora module to 4-bit linear may get different generations due to rounding errors."
)
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
kwargs = self.weight.__dict__
weight = self.base_layer.weight
kwargs = weight.__dict__
lora_data = self.get_delta_weight(active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data

w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + lora_data
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)

self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
weight.device
)
self.merged_adapters.append(active_adapter)

def unmerge(self):
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return

while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter not in self.lora_A.keys():
continue
warnings.warn(
"Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
)
kwargs = self.weight.__dict__
weight = self.base_layer.weight
kwargs = weight.__dict__
lora_data = self.get_delta_weight(active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data
self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) - lora_data
self.base_layer.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(
weight.device
)

def get_delta_weight(self, adapter):
return (
@@ -268,15 +258,15 @@ def get_delta_weight(self, adapter):
* self.scaling[adapter]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = super().forward(x)
result = self.base_layer.forward(x, *args, **kwargs)
elif self.merged:
result = super().forward(x)
result = self.base_layer.forward(x, *args, **kwargs)
else:
result = super().forward(x)
result = self.base_layer.forward(x, *args, **kwargs)
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
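The `merge()`/`unmerge()` paths above keep the pre-existing workaround: bitsandbytes does not support int8 dequantization directly, so the fp16 weight is recovered by running an identity matrix through an int8 matmul with the quantized weight and dequantizing the result. The same steps, pulled out as a standalone helper for readability; the helper name is hypothetical, `weight` and `state` correspond to `self.base_layer.weight` and `self.base_layer.state` in the diff, and the `bnb.functional` signatures may differ across bitsandbytes versions:

```python
import bitsandbytes as bnb
import torch


def dequantize_8bit_weight(weight, state) -> torch.Tensor:
    """Recover an fp16 copy of an int8 weight by multiplying it with an identity matrix."""
    if state.SCB is None:
        state.SCB = weight.SCB

    # Identity matrix over the input dimension, quantized and laid out for igemmlt.
    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")

    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)

    # int8 matmul with the identity, then dequantize: the result is the original weight, transposed.
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
```

The merge path adds the LoRA delta to this dequantized weight and re-quantizes via `bnb.nn.Int8Params`; unmerge subtracts it instead.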
52 changes: 31 additions & 21 deletions src/peft/tuners/lora/model.py
@@ -201,12 +201,24 @@ def _replace_module(parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer
elif hasattr(child, "quant_linear_module"):
child = child.quant_linear_module

# TODO: layers with base_layer don't need the weight to be copied, as they have a reference already
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias

if getattr(child, "state", None) is not None:
new_module.state = child.state
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)

# dispatch to correct device
Expand Down Expand Up @@ -256,9 +268,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
"index": target.index,
}
)
new_module = Linear8bitLt(
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
)
new_module = Linear8bitLt(adapter_name, target, **eightbit_kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
Expand All @@ -268,7 +278,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
"quant_type": target.weight.quant_type,
}
)
new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs)
new_module = Linear4bit(adapter_name, target, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
new_module = QuantLinear(adapter_name, target, **kwargs)
target.weight = target.qweight
@@ -389,28 +399,28 @@ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False, sa
padding=target.padding,
dilation=target.dilation,
)
elif is_bnb_available() and isinstance(target, bnb.nn.Linear8bitLt):
bias = target.bias is not None
elif is_bnb_available() and isinstance(target, Linear8bitLt):
bias = target.base_layer.bias is not None
new_module = bnb.nn.Linear8bitLt(
target.in_features,
target.out_features,
bias=bias,
has_fp16_weights=target.state.has_fp16_weights,
memory_efficient_backward=target.state.memory_efficient_backward,
threshold=target.state.threshold,
index=target.index,
device=target.weight.device,
has_fp16_weights=target.base_layer.state.has_fp16_weights,
memory_efficient_backward=target.base_layer.state.memory_efficient_backward,
threshold=target.base_layer.state.threshold,
index=target.base_layer.index,
device=target.base_layer.weight.device,
)
elif is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
bias = target.bias is not None
elif is_bnb_4bit_available() and isinstance(target, Linear4bit):
bias = target.base_layer.bias is not None
new_module = bnb.nn.Linear4bit(
target.in_features,
target.out_features,
bias=bias,
compute_dtype=target.compute_dtype,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
device=target.weight.device,
compute_dtype=target.base_layer.compute_dtype,
compress_statistics=target.base_layer.weight.compress_statistics,
quant_type=target.base_layer.weight.quant_type,
device=target.base_layer.weight.device,
)
else:
bias = target.bias is not None
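One subtle point in the `_replace_module` change above: when the module being replaced is itself a wrapper, its weights live on the wrapped module, and when the new module keeps a `base_layer` reference, copying weights would be redundant. A standalone sketch of just that unwrap-and-copy step; the function name is illustrative, and in the PR this logic stays inline in `_replace_module`:

```python
import torch


def copy_weights_from_wrapped(new_module: torch.nn.Module, child: torch.nn.Module) -> None:
    # Unwrap LoRA-style wrappers (base_layer) and GPTQ-style wrappers (quant_linear_module).
    if hasattr(child, "base_layer"):
        child = child.base_layer
    elif hasattr(child, "quant_linear_module"):
        child = child.quant_linear_module

    # Modules that keep a base_layer reference already see the original weights; no copy needed.
    if not hasattr(new_module, "base_layer"):
        new_module.weight = child.weight
        if hasattr(child, "bias"):
            new_module.bias = child.bias
```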
12 changes: 4 additions & 8 deletions tests/test_common_gpu.py
@@ -545,12 +545,8 @@ def test_8bit_merge_and_disable_lora(self):
self.assertFalse(torch.allclose(out_base, out_before, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(out_base, out_after, atol=atol, rtol=rtol))
self.assertTrue(isinstance(model, PeftModel))
self.assertTrue(
isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
)
self.assertTrue(
isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear8bitLt)
)
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt))

@require_torch_gpu
@pytest.mark.single_gpu_tests
@@ -633,5 +629,5 @@ def test_4bit_merge_and_disable_lora(self):
self.assertFalse(torch.allclose(out_base, out_before, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(out_base, out_after, atol=atol, rtol=rtol))
self.assertTrue(isinstance(model, PeftModel))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit))
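The test updates follow directly from the refactor: the LoRA layer is now a plain `torch.nn.Module` wrapper, so `isinstance(layer, bnb.nn.Linear8bitLt)` no longer matches it and the assertions target the LoRA classes instead. A small illustrative check, assuming `layer` is one of the `q_proj`/`v_proj` modules inspected in the tests (the helper itself is not part of the PR):

```python
import bitsandbytes as bnb


def describe_layer(layer) -> str:
    # After this PR the LoRA wrapper is not itself a bnb module; the quantized
    # linear it wraps is reachable through .base_layer instead.
    if hasattr(layer, "base_layer") and isinstance(layer.base_layer, bnb.nn.Linear8bitLt):
        return "LoRA wrapper around an 8-bit linear"
    if isinstance(layer, bnb.nn.Linear8bitLt):
        return "plain 8-bit linear (no adapter attached)"
    return "something else"
```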