adding support for int8 lora training
pacman100 committed Jan 24, 2023
1 parent ff8a5b9 commit d75746b
Showing 2 changed files with 41 additions and 1 deletion.
1 change: 1 addition & 0 deletions setup.py
@@ -43,6 +43,7 @@
        "torch>=1.13.0",
        "transformers",
        "accelerate",
+        "bitsandbytes",
    ],
    extras_require=extras,
    classifiers=[
41 changes: 40 additions & 1 deletion src/peft/tuners/lora.py
@@ -25,6 +25,8 @@
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

+import bitsandbytes as bnb
+
from ..utils import PeftConfig, PeftType, transpose


@@ -118,7 +120,9 @@ def _find_and_replace(self):
            if any(key.endswith(target_key) for target_key in self.peft_config.target_modules):
                parent, target, target_name = self._get_submodules(key)
                bias = target.bias is not None
-                if isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
+                if isinstance(target, bnb.nn.Linear8bitLt) and self.peft_config.enable_lora is None:
+                    new_module = Linear8bitLt(target.in_features, target.out_features, **kwargs)
+                elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
                    new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
                elif self.peft_config.enable_lora is not None:
                    kwargs.update({"enable_lora": self.peft_config.enable_lora})
@@ -358,3 +362,38 @@ def forward(self, x: torch.Tensor):
                after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
                result += self.zero_pad(after_B) * self.scaling
            return result
+
+
+class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
+    # Lora implemented in a dense layer
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        bnb.nn.Linear8bitLt.__init__(self, in_features, out_features, kwargs.get("bias", None))
+        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
+        # Actual trainable parameters
+        if r > 0:
+            self.lora_A = nn.Linear(in_features, r, bias=False)
+            self.lora_B = nn.Linear(r, out_features, bias=False)
+            self.scaling = self.lora_alpha / self.r
+        # Freezing the pre-trained weight matrix
+        self.weight.requires_grad = False
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if hasattr(self, "lora_A"):
+            # initialize A the same way as the default for nn.Linear and B to zero
+            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B.weight)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        if self.r > 0:
+            result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+        return result
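
Taken together, these changes let LoRA adapters be trained on top of a base model whose linear layers were loaded in int8 through bitsandbytes: _find_and_replace now swaps each targeted bnb.nn.Linear8bitLt for the LoRA-aware Linear8bitLt above, keeping the quantized weight frozen and training only the small lora_A/lora_B matrices. The sketch below is illustrative and not part of the commit; the model name, the q_proj/v_proj target modules, and the single training step are assumptions, and it presumes a transformers install with load_in_8bit support plus a CUDA device.

# Minimal int8 LoRA training sketch (illustrative; assumptions flagged above and in comments).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import LoraConfig, get_peft_model

model_name = "facebook/opt-350m"  # hypothetical choice; any causal LM with q_proj/v_proj projections
tokenizer = AutoTokenizer.from_pretrained(model_name)
# load_in_8bit=True makes the model's linear layers bnb.nn.Linear8bitLt, the case handled by this commit
model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto")

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)  # int8 base weights stay frozen; only lora_A/lora_B train

# Only the LoRA parameters require gradients, so the optimizer state stays small.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)

batch = tokenizer("int8 LoRA fine-tuning sketch", return_tensors="pt").to(model.device)
loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()
optimizer.step()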
