From d75746be70d5128e70b772a8cef407385490fbd3 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Wed, 25 Jan 2023 04:19:23 +0530
Subject: [PATCH] adding support for int8 lora training

---
 setup.py                |  1 +
 src/peft/tuners/lora.py | 41 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index d4725aba9a..fedd8e33f6 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@
         "torch>=1.13.0",
         "transformers",
         "accelerate",
+        "bitsandbytes",
     ],
     extras_require=extras,
     classifiers=[
diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index cfd3fd8553..6a60d08144 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -25,6 +25,8 @@
 import torch.nn.functional as F
 from transformers.pytorch_utils import Conv1D
 
+import bitsandbytes as bnb
+
 from ..utils import PeftConfig, PeftType, transpose
 
 
@@ -118,7 +120,9 @@ def _find_and_replace(self):
             if any(key.endswith(target_key) for target_key in self.peft_config.target_modules):
                 parent, target, target_name = self._get_submodules(key)
                 bias = target.bias is not None
-                if isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
+                if isinstance(target, bnb.nn.Linear8bitLt) and self.peft_config.enable_lora is None:
+                    new_module = Linear8bitLt(target.in_features, target.out_features, **kwargs)
+                elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
                     new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
                 elif self.peft_config.enable_lora is not None:
                     kwargs.update({"enable_lora": self.peft_config.enable_lora})
@@ -358,3 +362,38 @@ def forward(self, x: torch.Tensor):
                 after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
                 result += self.zero_pad(after_B) * self.scaling
             return result
+
+
+class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
+    # Lora implemented in a dense layer
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        bnb.nn.Linear8bitLt.__init__(self, in_features, out_features, kwargs.get("bias", None))
+        LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
+        # Actual trainable parameters
+        if r > 0:
+            self.lora_A = nn.Linear(in_features, r, bias=False)
+            self.lora_B = nn.Linear(r, out_features, bias=False)
+            self.scaling = self.lora_alpha / self.r
+            # Freezing the pre-trained weight matrix
+            self.weight.requires_grad = False
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if hasattr(self, "lora_A"):
+            # initialize A the same way as the default for nn.Linear and B to zero
+            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B.weight)
+
+    def forward(self, x: torch.Tensor):
+        result = super().forward(x)
+        if self.r > 0:
+            result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+        return result
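
Usage sketch (not part of the patch): a minimal example of how the new Linear8bitLt path might be exercised after this change. The checkpoint name, target_modules, and LoRA hyperparameters below are illustrative assumptions rather than values taken from the patch, and it assumes the existing LoraConfig / get_peft_model entry points plus a transformers install with bitsandbytes available for load_in_8bit.

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Load the base model with its linear layers quantized to int8 by
    # bitsandbytes; _find_and_replace will then match bnb.nn.Linear8bitLt
    # modules and swap in the LoRA-augmented Linear8bitLt added above.
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",   # assumed example checkpoint
        load_in_8bit=True,
        device_map="auto",
    )

    config = LoraConfig(
        r=8,                   # assumed rank and scaling values
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],  # assumed attention projections
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, config)
    model.print_trainable_parameters()
    # Only the lora_A / lora_B parameters require grad; the int8 base
    # weights stay frozen, as set in Linear8bitLt.__init__ above.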