AQLM support for LoRA #1476
@@ -0,0 +1,103 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch

from peft.import_utils import is_aqlm_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_aqlm_available():
    from aqlm import QuantizedLinear as AqlmQuantizedLinear
Review comment: Why do we need the alias? There is no name conflict because the PEFT class is called `QuantLinear`.
Reply: You're right, we don't. I simply didn't like those two names being similar.
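For illustration only (not part of the diff), dropping the alias as suggested would amount to importing the class directly and checking against it in `dispatch_aqlm` further down:

```python
# Sketch of the suggested simplification: no alias is needed, since the PEFT wrapper
# defined below is named QuantLinear while aqlm's class is named QuantizedLinear.
if is_aqlm_available():
    from aqlm import QuantizedLinear

# ...and the dispatch check would then read:
#     if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
```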


class QuantLinear(torch.nn.Module, LoraLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
        # for backwards compatibility
        self.quant_linear_module = base_layer
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

Review comment: We should be able to just do
Reply: Base layer is initialized during

    def forward(self, x: torch.Tensor):
        # note: logic differs from default Linear because merging is not supported
        result = self.quant_linear_module(x)

        if self.disable_adapters:
            return result

        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = x.to(lora_A.weight.dtype)

            output = lora_B(lora_A(dropout(x)))
            if requires_conversion:
                output = output.to(expected_dtype)
            output = output * scaling
            result += output
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep

    # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
    # def reset_lora_parameters(self, adapter_name):
    #     if adapter_name in self.lora_A.keys():
    #         torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
    #         torch.nn.init.zeros_(self.lora_B[adapter_name].weight)


def dispatch_aqlm(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_aqlm_available() and isinstance(target_base_layer, AqlmQuantizedLinear):
        new_module = QuantLinear(target, adapter_name, **kwargs)
        target.qweight = target_base_layer.codes
Review comment: Is this used?
Reply: Yes, this is the place where quantized linear layers get wrapped with a LoRA wrapper.

    return new_module
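For context, here is a minimal usage sketch (not part of the diff) of how an AQLM-quantized checkpoint flows through this dispatch path; the checkpoint id and `target_modules` below are placeholders:

```python
# Hypothetical end-to-end sketch: the checkpoint id and target_modules are illustrative only.
# Any module whose base layer is an aqlm.QuantizedLinear is routed through dispatch_aqlm,
# so each targeted layer ends up wrapped in the QuantLinear LoRA layer defined above.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("some-org/some-aqlm-quantized-model")  # placeholder id
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()  # only the LoRA A/B matrices are trainable
```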
@@ -53,6 +53,7 @@
from peft.utils import SAFETENSORS_WEIGHTS_NAME

from .testing_utils import (
    require_aqlm,
    require_auto_awq,
    require_auto_gptq,
    require_bitsandbytes,
@@ -1383,6 +1384,149 @@ def test_model_loaded_in_float16_working(self):
        trainer.train()


@require_torch_gpu
@require_aqlm
class PeftAqlmGPUTests(unittest.TestCase):
    r"""
    AQLM + peft tests
    """

    def setUp(self):
        self.causal_lm_model_id = "peft-internal-testing/opt-125m-awq"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        gc.collect()
        torch.cuda.empty_cache()

    def _check_inference_finite(self, model, batch):
        # try inference without Trainer class
        training = model.training
        model.eval()
        output = model(**batch.to(model.device))
        assert torch.isfinite(output.logits).all()
        model.train(training)

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_awq(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
Review comment: When running the test locally, I get the following error:
Not sure if that's the one that would be fixed by the transformers PR or if it's a different issue.

Reply: For that you need to check out that transformers PR indeed. Maybe we can do a version check of transformers from the PEFT side, what do you think? @BenjaminBossan @BlackSamorez

Reply: If we know what version this will be contained in, this would be a possibility. It would mean that we don't have a test at all until it's released, though.

Reply: Yes! It should be included in 4.38.0.

Reply: @BenjaminBossan @BlackSamorez that's not the error I usually get when using

Reply: @BenjaminBossan note that in our daily CI we build transformers from main, so IMO once the transformers PR is merged we can merge this PR! 🙏

Reply: Looks like it has been merged, meaning that

Reply: Okay, so this test should run successfully when we test against transformers main. Still, let's add logic to skip the test if the transformers version is too old, to ensure that CI stays green even when testing against the transformers release version.

Reply: @BenjaminBossan added

@unittest.skipUnless(
    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
    "test requires `transformers>=4.38.0`",
)
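As a minimal, self-contained sketch of the version gate mentioned above (assuming `packaging` and the stdlib `importlib.metadata` are available, as they are in a transformers environment on Python >= 3.8):

```python
# Sketch of the skip condition discussed in the thread; the class body is elided,
# only the gate itself matters here.
import importlib.metadata
import unittest

from packaging import version


@unittest.skipUnless(
    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
    "test requires `transformers>=4.38.0`",
)
class PeftAqlmGPUTests(unittest.TestCase):
    ...
```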
                self.causal_lm_model_id,
                device_map="auto",
            )

            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)

            data = load_dataset("ybelkada/english_quotes_copy")
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

            # TODO: deal correctly with this case in transformers
            model._is_quantized_training_enabled = True

            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                    fp16=True,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)

            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.multi_gpu_tests
    @require_torch_multi_gpu
    def test_causal_lm_training_multi_gpu(self):
        r"""
        Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
        correctly.
        """

        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map="auto",
            )

            assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count()))

            model = prepare_model_for_kbit_training(model)

            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)

            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )

            model = get_peft_model(model, config)

            data = load_dataset("Abirate/english_quotes")
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)

            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None


@require_torch_gpu
@require_auto_awq
class PeftAwqGPUTests(unittest.TestCase):
Review comment: Could you please move this install into the installation block below (lines 60-67) to avoid creating another cache step? Also, do you think it's a good idea to pin the version like that? It means that if there is a new aqlm release that breaks something in PEFT, we wouldn't notice it.

Reply: Moved, and replaced `==` with `>=`.