Fix: Multiple adapters with bnb layers #1243

Merged
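This PR fixes adding or loading a second adapter on models quantized with bitsandbytes (8-bit and 4-bit). As the diffs below show, the tuner code previously read bnb-specific attributes (.state, .index, .compute_dtype, ...) from target, which, once a first adapter has been injected, is typically already a PEFT wrapper rather than the raw bnb layer; it now reads them from target_base_layer, and LoRA's _create_and_replace additionally recognizes existing bnb LoRA layers so it updates them in place. A minimal sketch of the user-facing workflow, modelled on the new tests further down (the model and adapter ids are the ones used in those tests; a CUDA GPU with bitsandbytes installed is assumed):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

model_id = "facebook/opt-350m"
peft_model_id = "ybelkada/test-st-lora"

# Load the base model in 4-bit and attach a first adapter.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
model = PeftModel.from_pretrained(model, peft_model_id)
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))

# Loading a second adapter onto the same quantized layers used to fail (#1239);
# with this fix it works and can be activated as usual.
model.load_adapter(peft_model_id, "adapter2")
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
```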
14 changes: 7 additions & 7 deletions src/peft/tuners/adalora/model.py
@@ -173,20 +173,20 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
             kwargs.update(
                 {
-                    "has_fp16_weights": target.state.has_fp16_weights,
-                    "memory_efficient_backward": target.state.memory_efficient_backward,
-                    "threshold": target.state.threshold,
-                    "index": target.index,
+                    "has_fp16_weights": target_base_layer.state.has_fp16_weights,
+                    "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
+                    "threshold": target_base_layer.state.threshold,
+                    "index": target_base_layer.index,
                 }
             )
             new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
         elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
             fourbit_kwargs.update(
                 {
-                    "compute_dtype": target.compute_dtype,
-                    "compress_statistics": target.weight.compress_statistics,
-                    "quant_type": target.weight.quant_type,
+                    "compute_dtype": target_base_layer.compute_dtype,
+                    "compress_statistics": target_base_layer.weight.compress_statistics,
+                    "quant_type": target_base_layer.weight.quant_type,
                 }
             )
             new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
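The same target → target_base_layer change recurs in the IA³ and LoRA tuners below. The reason is that by the time a second adapter is injected, target is no longer the raw bnb module but an existing PEFT layer wrapping it, so the bnb-specific attributes read above (.state, .index, .compute_dtype, ...) only exist on its base layer. Below is a hedged sketch of how target_base_layer is assumed to be resolved earlier in _create_new_module; that resolution code is outside the hunks shown here, and the helper name is illustrative:

```python
def resolve_base_layer(target):
    # On the first adapter, `target` is the raw bnb module (e.g. bnb.nn.Linear8bitLt)
    # and is returned unchanged. On a later adapter, `target` is already a PEFT
    # wrapper exposing a `base_layer`, and the bnb attributes live on that layer.
    if hasattr(target, "base_layer"):
        return target.get_base_layer()
    return target
```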
14 changes: 7 additions & 7 deletions src/peft/tuners/ia3/model.py
@@ -100,20 +100,20 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
             eightbit_kwargs = kwargs.copy()
             eightbit_kwargs.update(
                 {
-                    "has_fp16_weights": target.state.has_fp16_weights,
-                    "memory_efficient_backward": target.state.memory_efficient_backward,
-                    "threshold": target.state.threshold,
-                    "index": target.index,
+                    "has_fp16_weights": target_base_layer.state.has_fp16_weights,
+                    "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
+                    "threshold": target_base_layer.state.threshold,
+                    "index": target_base_layer.index,
                 }
             )
             new_module = Linear8bitLt(target, adapter_name, is_feedforward=is_feedforward, **eightbit_kwargs)
         elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit):
             fourbit_kwargs = kwargs.copy()
             fourbit_kwargs.update(
                 {
-                    "compute_dtype": target.compute_dtype,
-                    "compress_statistics": target.weight.compress_statistics,
-                    "quant_type": target.weight.quant_type,
+                    "compute_dtype": target_base_layer.compute_dtype,
+                    "compress_statistics": target_base_layer.weight.compress_statistics,
+                    "quant_type": target_base_layer.weight.quant_type,
                 }
             )
             new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs)
20 changes: 15 additions & 5 deletions src/peft/tuners/lora/model.py
@@ -163,6 +163,16 @@ def _create_and_replace(
         if quantization_config is not None:
             kwargs["gptq_quantization_config"] = quantization_config

+        linear_types = (Linear,)
+        if is_bnb_available():
+            from .bnb import Linear8bitLt
+
+            linear_types += (Linear8bitLt,)
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
+            linear_types += (Linear4bit,)
+
         # TODO: better deal with that
         if isinstance(target, Conv2d):
             target.update_layer_conv2d(
@@ -180,7 +190,7 @@ def _create_and_replace(
                 lora_config.lora_dropout,
                 lora_config.init_lora_weights,
             )
-        elif isinstance(target, Linear):
+        elif isinstance(target, linear_types):
             target.update_layer(
                 adapter_name,
                 r,
@@ -284,15 +294,15 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
             fourbit_kwargs = kwargs.copy()
             fourbit_kwargs.update(
                 {
-                    "compute_dtype": target.compute_dtype,
-                    "compress_statistics": target.weight.compress_statistics,
-                    "quant_type": target.weight.quant_type,
+                    "compute_dtype": target_base_layer.compute_dtype,
+                    "compress_statistics": target_base_layer.weight.compress_statistics,
+                    "quant_type": target_base_layer.weight.quant_type,
                 }
             )
             new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
         elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
             new_module = QuantLinear(target, adapter_name, **kwargs)
-            target.weight = target.qweight
+            target.qweight = target_base_layer.qweight
         elif isinstance(target_base_layer, torch.nn.Embedding):
             embedding_kwargs = kwargs.copy()
             embedding_kwargs.pop("fan_in_fan_out", None)
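The first hunk in this file widens the isinstance check so that a module which is already a LoRA bnb layer (from a previously added adapter) is updated in place rather than treated as an unknown layer and re-wrapped. The following is a simplified, hedged sketch of the resulting dispatch, written as if it were a LoraModel method; the function name is illustrative, the real _create_and_replace also handles Conv2d and Embedding targets, and the import paths are assumed from the PEFT version this PR targets:

```python
from peft.import_utils import is_bnb_available, is_bnb_4bit_available
from peft.tuners.lora.layer import Linear


def add_lora_adapter_to_target(self, lora_config, adapter_name, target, target_name, parent, r, **kwargs):
    # Simplified sketch, not the actual method body.
    linear_types = (Linear,)
    if is_bnb_available():
        from peft.tuners.lora.bnb import Linear8bitLt

        linear_types += (Linear8bitLt,)
    if is_bnb_4bit_available():
        from peft.tuners.lora.bnb import Linear4bit

        linear_types += (Linear4bit,)

    if isinstance(target, linear_types):
        # Second (or later) adapter on an already-wrapped layer, now including the
        # bnb-quantized wrappers: register the new adapter on the existing module.
        target.update_layer(
            adapter_name,
            r,
            lora_config.lora_alpha,
            lora_config.lora_dropout,
            lora_config.init_lora_weights,
        )
    else:
        # First adapter on this layer: create a fresh LoRA module around it.
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
        self._replace_module(parent, target_name, new_module, target)
```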
121 changes: 117 additions & 4 deletions tests/test_common_gpu.py
@@ -19,6 +19,7 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
@@ -31,6 +32,7 @@
 )

 from peft import (
+    AdaLoraConfig,
     AdaptionPromptConfig,
     IA3Config,
     LoraConfig,
@@ -186,17 +188,128 @@ def test_ia3_bnb_8bit_quantization(self):
     @require_bitsandbytes
     @pytest.mark.multi_gpu_tests
     @pytest.mark.single_gpu_tests
-    def test_lora_bnb_4bit_quantization_from_pretrained_safetensors(self):
+    @parameterized.expand(["4bit", "8bit"])
+    def test_lora_bnb_quantization_from_pretrained_safetensors(self, quantization):
         r"""
-        Test that tests if the 4bit quantization using LoRA works as expected with safetensors weights.
+        Tests that the bnb quantization using LoRA works as expected with safetensors weights.
         """
         model_id = "facebook/opt-350m"
         peft_model_id = "ybelkada/test-st-lora"
+        kwargs = {"device_map": "auto"}
+        if quantization == "4bit":
+            kwargs["load_in_4bit"] = True
+        else:
+            kwargs["load_in_8bit"] = True

-        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
         model = PeftModel.from_pretrained(model, peft_model_id)

-        _ = model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+        model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+        # loading a 2nd adapter works, #1239
+        model.load_adapter(peft_model_id, "adapter2")
+        model.set_adapter("adapter2")
+        model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+    @require_bitsandbytes
+    @pytest.mark.multi_gpu_tests
+    @pytest.mark.single_gpu_tests
+    @parameterized.expand(["4bit", "8bit"])
+    def test_adalora_bnb_quantization_from_pretrained_safetensors(self, quantization):
+        r"""
+        Tests that the bnb quantization using AdaLora works as expected with safetensors weights.
+        """
+        model_id = "facebook/opt-350m"
+        kwargs = {"device_map": "auto"}
+        if quantization == "4bit":
+            kwargs["load_in_4bit"] = True
+        else:
+            kwargs["load_in_8bit"] = True
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
+        config = AdaLoraConfig(task_type=TaskType.CAUSAL_LM)
+        peft_model = get_peft_model(model, config)
+        peft_model = prepare_model_for_kbit_training(peft_model)
+        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            peft_model.save_pretrained(tmp_dir)
+            model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
+            model = PeftModel.from_pretrained(model, tmp_dir)
+            model = prepare_model_for_kbit_training(peft_model)
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+            # loading a 2nd adapter works, #1239
+            model.load_adapter(tmp_dir, "adapter2")
+            model.set_adapter("adapter2")
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+    @require_bitsandbytes
+    @pytest.mark.multi_gpu_tests
+    @pytest.mark.single_gpu_tests
+    @parameterized.expand(["4bit", "8bit"])
+    def test_ia3_bnb_quantization_from_pretrained_safetensors(self, quantization):
+        r"""
+        Tests that the bnb quantization using IA³ works as expected with safetensors weights.
+        """
+        model_id = "facebook/opt-350m"
+        kwargs = {"device_map": "auto"}
+        if quantization == "4bit":
+            kwargs["load_in_4bit"] = True
+        else:
+            kwargs["load_in_8bit"] = True
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
+        config = IA3Config(task_type=TaskType.CAUSAL_LM)
+        peft_model = get_peft_model(model, config)
+        peft_model = prepare_model_for_kbit_training(peft_model)
+        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            peft_model.save_pretrained(tmp_dir)
+            model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
+            model = PeftModel.from_pretrained(model, tmp_dir)
+            model = prepare_model_for_kbit_training(model)
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+            # loading a 2nd adapter works, #1239
+            model.load_adapter(tmp_dir, "adapter2")
+            model.set_adapter("adapter2")
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+    @pytest.mark.single_gpu_tests
+    def test_lora_gptq_quantization_from_pretrained_safetensors(self):
+        r"""
+        Tests that the autogptq quantization using LoRA works as expected with safetensors weights.
+        """
+        from transformers import GPTQConfig
+
+        model_id = "marcsun13/opt-350m-gptq-4bit"
+        quantization_config = GPTQConfig(bits=4, use_exllama=False)
+        kwargs = {
+            "pretrained_model_name_or_path": model_id,
+            "torch_dtype": torch.float16,
+            "device_map": "auto",
+            "quantization_config": quantization_config,
+        }
+        model = AutoModelForCausalLM.from_pretrained(**kwargs)
+        model = prepare_model_for_kbit_training(model)
+
+        config = LoraConfig(task_type="CAUSAL_LM")
+        peft_model = get_peft_model(model, config)
+        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            peft_model.save_pretrained(tmp_dir)
+            model = AutoModelForCausalLM.from_pretrained(**kwargs)
+            model = PeftModel.from_pretrained(model, tmp_dir)
+            model = prepare_model_for_kbit_training(model)
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
+            # loading a 2nd adapter works, #1239
+            model.load_adapter(tmp_dir, "adapter2")
+            model.set_adapter("adapter2")
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+
     @require_bitsandbytes
     @pytest.mark.multi_gpu_tests
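Beyond generation, a quick structural check can confirm the behavior these tests rely on: after loading a second adapter, the quantized base layer is still wrapped exactly once and both adapters are registered on the same LoRA module. This is a hedged sketch, not part of the PR; the helper name, the example module path, and the adapter names are illustrative, and it assumes bitsandbytes with 4-bit support is installed:

```python
import bitsandbytes as bnb
from peft.tuners.lora.bnb import Linear4bit


def assert_adapters_share_base_layer(peft_model, module_name, adapter_names=("default", "adapter2")):
    # e.g. module_name = "base_model.model.model.decoder.layers.0.self_attn.q_proj" for OPT
    module = peft_model.get_submodule(module_name)
    # The target should be a single LoRA wrapper around the bnb layer, not a wrapper of a wrapper.
    assert isinstance(module, Linear4bit)
    assert isinstance(module.base_layer, bnb.nn.Linear4bit)
    # Both adapters live side by side on the same module.
    assert set(adapter_names) <= set(module.lora_A.keys())
```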