Add 4-bit support to IA3 - Outperforms QLoRA in both speed and memory consumption #864

Merged
13 commits, merged Sep 26, 2023
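For context, this is the workflow the PR enables: load the base model in 4-bit with bitsandbytes and attach (IA)^3 vectors through PEFT. The sketch below is illustrative and not part of the PR; the model name, target modules, and dtypes are assumptions.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import IA3Config, get_peft_model

# Load the base model in 4-bit NF4, as in QLoRA-style fine-tuning.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative model choice
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach (IA)^3 vectors; with this PR the bnb Linear4bit layers get wrapped as well.
ia3_config = IA3Config(
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj", "fc2"],  # illustrative choice for an OPT-style model
    feedforward_modules=["fc2"],
)
model = get_peft_model(model, ia3_config)
model.print_trainable_parameters()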
7 changes: 6 additions & 1 deletion src/peft/tuners/ia3/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.import_utils import is_bnb_available
from peft.import_utils import is_bnb_4bit_available, is_bnb_available

from .config import IA3Config
from .layer import IA3Layer, Linear
@@ -27,3 +27,8 @@
    from .bnb import Linear8bitLt

    __all__ += ["Linear8bitLt"]

if is_bnb_4bit_available():
    from .bnb import Linear4bit

    __all__ += ["Linear4bit"]
170 changes: 118 additions & 52 deletions src/peft/tuners/ia3/bnb.py
@@ -16,61 +16,127 @@
import bitsandbytes as bnb
import torch

from peft.import_utils import is_bnb_4bit_available, is_bnb_available

from .layer import IA3Layer


if is_bnb_available():

    class Linear8bitLt(bnb.nn.Linear8bitLt, IA3Layer):
        # (IA)^3 implemented in a dense layer
        def __init__(
            self,
            adapter_name,
            in_features,
            out_features,
            is_feedforward,
            **kwargs,
        ) -> None:
            bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
            )
            IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
            self.is_feedforward = is_feedforward

            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False

            init_ia3_weights = kwargs.pop("init_ia3_weights", True)
            self.update_layer(adapter_name, init_ia3_weights)
            self.set_adapter(adapter_name)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.disable_adapters:
                return super().forward(x)

            ia3_scaling = 1
            for active_adapter in self.active_adapters:
                if active_adapter not in self.ia3_l.keys():
                    continue
                ia3_scaling *= self.ia3_l[active_adapter].flatten()

            requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32)
            if requires_conversion:
                x = x.float()
            if self.is_feedforward:
                result = super().forward(x * ia3_scaling)
                expected_dtype = result.dtype
            else:
                result = super().forward(x)
                expected_dtype = result.dtype
                result = result * ia3_scaling

            if requires_conversion:
                result = result.to(expected_dtype)

            return result


if is_bnb_4bit_available():

    class Linear4bit(bnb.nn.Linear4bit, IA3Layer):
        # IA3 implemented in a dense layer
        def __init__(
            self,
            adapter_name,
            in_features,
            out_features,
            is_feedforward,
            **kwargs,
        ) -> None:
            bnb.nn.Linear4bit.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                compute_dtype=kwargs.get("compute_dtype", torch.float32),
                compress_statistics=kwargs.get("compress_statistics", True),
                quant_type=kwargs.get("quant_type", "nf4"),
            )
            IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
            self.is_feedforward = is_feedforward

            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False

            init_ia3_weights = kwargs.pop("init_ia3_weights", True)
            self.update_layer(adapter_name, init_ia3_weights)
            self.set_adapter(adapter_name)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.disable_adapters:
                return super().forward(x)

            ia3_scaling = 1
            for active_adapter in self.active_adapters:
                if active_adapter not in self.ia3_l.keys():
                    continue
                ia3_scaling *= self.ia3_l[active_adapter].flatten()

            requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32)
            if requires_conversion:
                x = x.float()
            if self.is_feedforward:
                result = super().forward(x * ia3_scaling)
                expected_dtype = result.dtype
            else:
                result = super().forward(x)
                expected_dtype = result.dtype
                result = result * ia3_scaling

            result = result.clone()
            # adalora.py and lora.py both suggest that this clone is necessary for 4-bit training on
            # older versions of PyTorch; the same workaround is duplicated here.

            if requires_conversion:
                result = result.to(expected_dtype)

            return result
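Both quantized classes apply the same (IA)^3 rule as the existing dense layer: a feedforward target rescales the input before the (quantized) matmul, while any other target rescales the matmul output. A minimal sketch with plain tensors; the shapes are chosen only for illustration.

import torch

x = torch.randn(2, 4)       # two tokens, hidden size 4
weight = torch.randn(4, 4)  # frozen base weight
ia3_l = torch.ones(4)       # learned (IA)^3 vector, initialised to ones

# Feedforward target: rescale the input, then apply the frozen linear map.
out_ff = (x * ia3_l) @ weight.T

# Attention key/value target: apply the frozen linear map, then rescale the output.
out_kv = (x @ weight.T) * ia3_l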
30 changes: 29 additions & 1 deletion src/peft/tuners/ia3/model.py
@@ -21,7 +21,7 @@
import torch
from transformers.pytorch_utils import Conv1D

from peft.import_utils import is_bnb_available
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTuner
from peft.utils import (
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
@@ -39,6 +39,11 @@

    from .bnb import Linear8bitLt

if is_bnb_4bit_available():
    import bitsandbytes as bnb

    from .bnb import Linear4bit, Linear8bitLt


class IA3Model(BaseTuner):
"""
@@ -82,6 +87,7 @@ def __init__(self, model, config, adapter_name):
    def _create_new_module(ia3_config, adapter_name, target, **kwargs):
        bias = hasattr(target, "bias") and target.bias is not None
        loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
        loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
        is_feedforward = kwargs.pop("is_feedforward", False)

        if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
@@ -102,6 +108,23 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
                bias=bias,
                **eightbit_kwargs,
            )
        elif loaded_in_4bit and isinstance(target, bnb.nn.Linear4bit):
            fourbit_kwargs = kwargs.copy()
            fourbit_kwargs.update(
                {
                    "compute_dtype": target.compute_dtype,
                    "compress_statistics": target.weight.compress_statistics,
                    "quant_type": target.weight.quant_type,
                }
            )
            new_module = Linear4bit(
                adapter_name,
                target.in_features,
                target.out_features,
                is_feedforward,
                bias=bias,
                **fourbit_kwargs,
            )
        else:
            # Create a new Linear module with (IA)^3 parameters for torch.nn.Linear
            # or Conv1D modules
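The 4-bit branch mirrors the 8-bit one: the quantization settings are read off the module being replaced, so the new (IA)^3 wrapper dequantizes exactly like the original layer. A small sketch of where those attributes live on a bitsandbytes layer; sizes and dtypes are illustrative assumptions.

import bitsandbytes as bnb
import torch

# Roughly what transformers creates for each nn.Linear when load_in_4bit=True.
layer = bnb.nn.Linear4bit(
    1024, 1024, bias=True, compute_dtype=torch.float16, compress_statistics=True, quant_type="nf4"
)

# These are the attributes the dispatch above copies into fourbit_kwargs.
print(layer.compute_dtype)               # torch.float16
print(layer.weight.compress_statistics)  # True
print(layer.weight.quant_type)           # "nf4"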
@@ -156,6 +179,7 @@ def _create_and_replace(
        **optional_kwargs,
    ):
        loaded_in_8bit = optional_kwargs["loaded_in_8bit"]
        loaded_in_4bit = optional_kwargs["loaded_in_4bit"]
        current_key = optional_kwargs["current_key"]

        # check if target module is in feedforward_modules
@@ -168,6 +192,7 @@
            "fan_in_fan_out": ia3_config.fan_in_fan_out,
            "init_ia3_weights": ia3_config.init_ia3_weights,
            "loaded_in_8bit": loaded_in_8bit,
            "loaded_in_4bit": loaded_in_4bit,
            "is_feedforward": is_feedforward,
        }

@@ -257,6 +282,9 @@ def merge_and_unload(self):
        if getattr(self.model, "is_loaded_in_8bit", False):
            raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")

        if getattr(self.model, "is_loaded_in_4bit", False):
            raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")

        key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
        for key in key_list:
            try:
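Since merge_and_unload() now refuses both 8-bit and 4-bit base models, calling code has to keep the (IA)^3 vectors as a separate adapter when the base is quantized. A hedged sketch of that branch; the function name is illustrative.

def maybe_merge(peft_model):
    # merge_and_unload() raises ValueError for quantized bases, so check first.
    quantized = getattr(peft_model, "is_loaded_in_8bit", False) or getattr(
        peft_model, "is_loaded_in_4bit", False
    )
    if quantized:
        # Keep the adapter separate, e.g. peft_model.save_pretrained("ia3-adapter").
        return peft_model
    return peft_model.merge_and_unload()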