AQLM support for LoRA #1476

Merged
merged 14 commits on Feb 22, 2024
4 changes: 4 additions & 0 deletions docker/peft-gpu/Dockerfile
@@ -62,6 +62,10 @@ RUN source activate peft && \
git+https://github.com/huggingface/accelerate \
peft[test]@git+https://github.com/huggingface/peft

# Add aqlm for quantization testing
RUN source activate peft && \
pip install aqlm[gpu]>=1.0.2

RUN source activate peft && \
pip freeze | grep transformers

22 changes: 22 additions & 0 deletions docs/source/developer_guides/quantization.md
@@ -21,6 +21,7 @@ Quantization represents data with fewer bits, making it a useful technique for r
* optimizing which model weights are quantized with the [AWQ](https://hf.co/papers/2306.00978) algorithm
* independently quantizing each row of a weight matrix with the [GPTQ](https://hf.co/papers/2210.17323) algorithm
* quantizing to 8-bit and 4-bit precision with the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library
* quantizing to as low as 2-bit precision with the [AQLM](https://arxiv.org/abs/2401.06118) algorithm

However, after a model is quantized it isn't typically further trained for downstream tasks because training can be unstable due to the lower precision of the weights and activations. But since PEFT methods only add *extra* trainable parameters, this allows you to train a quantized model with a PEFT adapter on top! Combining quantization with PEFT can be a good strategy for training even the largest models on a single GPU. For example, [QLoRA](https://hf.co/papers/2305.14314) is a method that quantizes a model to 4-bits and then trains it with LoRA. This method allows you to finetune a 65B parameter model on a single 48GB GPU!

@@ -137,6 +138,27 @@ QLoRA adds trainable weights to all the linear layers in the transformer archite
config = LoraConfig(target_modules="all-linear", ...)
```

## AQLM quantization

Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for large language models. It quantizes multiple weights together, taking advantage of the interdependencies between them: each group of 8-16 weights is represented as a sum of multiple vector codes. This allows models to be compressed down to as low as 2 bits per weight with only a small loss in accuracy.
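
To make the idea concrete, here is a toy sketch of the sum-of-vector-codes reconstruction for a single weight group. It is illustrative only (not the actual `aqlm` kernels), and the group size, number of codebooks, and codebook size below are made-up values:

```py
import torch

group_size = 8         # number of weights encoded together (AQLM uses 8-16)
num_codebooks = 2      # number of vector codes summed per group
codebook_size = 2**8   # number of entries per codebook

# learned codebooks: each entry is a vector covering one full weight group
codebooks = torch.randn(num_codebooks, codebook_size, group_size)
# the discrete codes stored for a single weight group, one index per codebook
codes = torch.randint(codebook_size, (num_codebooks,))

# dequantization: the weight group is the sum of the selected codebook vectors
weight_group = codebooks[torch.arange(num_codebooks), codes].sum(dim=0)
print(weight_group.shape)  # torch.Size([8])
```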

Since the AQLM quantization process is computationally expensive, using prequantized models is recommended. A partial list of available models can be found in the official aqlm [repository](https://github.com/Vahe1994/AQLM).
Member

It would be nice (and better for adoption) to have safetensors for all of these models.

Contributor Author

I mostly did safetensors for the models where we needed a low RAM footprint for demos. We're currently updating the models themselves as well, and we'll definitely standardize the checkpoints once we're done.


The models support LoRA adapter tuning. To tune a quantized model you'll need to install the `aqlm` inference library: `pip install aqlm>=1.0.2`. Finetuned LoRA adapters must be saved separately, as merging them into the AQLM-quantized weights is not possible.

```py
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

quantized_model = AutoModelForCausalLM.from_pretrained(
"BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,
)

peft_config = LoraConfig(...)

quantized_model = get_peft_model(quantized_model, peft_config)
```
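
Since merging is not possible, the finetuned adapter is saved on its own and re-attached to a freshly loaded quantized base model for inference. A minimal sketch of that workflow, assuming a hypothetical local adapter path:

```py
# saves only the LoRA adapter weights; the AQLM base weights stay untouched
quantized_model.save_pretrained("mixtral-aqlm-lora-adapter")

# later: reload the quantized base model and attach the saved adapter
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
    torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base_model, "mixtral-aqlm-lora-adapter")
```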

You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning.
Member

How about adding this notebook to the examples/ folder in PEFT?

Contributor Author

I think that this notebook will suffice as an example in the docs, but it's not good enough to put on GitHub. It'll probably be replaced in a few weeks anyway, once we have better models, simpler PyPI installs, and a generally better example.


## Next steps

If you're interested in learning more about quantization, the following may be helpful:
4 changes: 4 additions & 0 deletions src/peft/import_utils.py
@@ -65,5 +65,9 @@ def is_torch_tpu_available(check_device=True):
return False


def is_aqlm_available():
return importlib.util.find_spec("aqlm") is not None


def is_auto_awq_available():
return importlib.util.find_spec("awq") is not None
100 changes: 100 additions & 0 deletions src/peft/tuners/lora/aqlm.py
@@ -0,0 +1,100 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch

from peft.import_utils import is_aqlm_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_aqlm_available():
from aqlm import QuantizedLinear


class AqlmLoraLinear(torch.nn.Module, LoraLayer):
def __init__(
self,
base_layer,
adapter_name: str,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
):
super().__init__()
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
result = self.base_layer(x)

if self.disable_adapters:
return result

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
output = output.to(expected_dtype)
output = output * scaling
result += output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep

# TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
# def reset_lora_parameters(self, adapter_name):
# if adapter_name in self.lora_A.keys():
# torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
# torch.nn.init.zeros_(self.lora_B[adapter_name].weight)


def dispatch_aqlm(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None

if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
new_module = AqlmLoraLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.codes
Contributor

is this used?

Contributor Author

Yes, this is the place where quantized linear layers get wrapped with a LoRA wrapper.
`qweight` itself is there simply to get its device here.


return new_module
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -66,6 +66,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
# Megatron ColumnParallelLinear,RowParallelLinear
in_features, out_features = base_layer.input_size, base_layer.output_size
elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
# AQLM QuantizedLinear
in_features, out_features = base_layer.in_features, base_layer.out_features
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
# Awq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
5 changes: 3 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -38,6 +38,7 @@
)
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties

from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
from .config import LoraConfig
from .gptq import dispatch_gptq
@@ -157,7 +158,7 @@ def _create_and_replace(
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
}

quant_methods = ["gptq", "awq"]
quant_methods = ["gptq", "aqlm", "awq"]
for quant_method in quant_methods:
quantization_config = get_quantization_config(self.model, method=quant_method)
if quantization_config is not None:
@@ -247,7 +248,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

dispatchers.append(dispatch_bnb_4bit)

dispatchers.extend([dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])
dispatchers.extend([dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])

new_module = None
for dispatcher in dispatchers:
5 changes: 3 additions & 2 deletions src/peft/utils/other.py
@@ -92,20 +92,21 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
"""
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}

for name, param in model.named_parameters():
# freeze base model's layers
param.requires_grad = False

if not is_gptq_quantized:
if not is_gptq_quantized and not is_aqlm_quantized:
# cast all non INT8 parameters to fp32
for param in model.parameters():
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)

if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
# When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
# For backward compatibility
88 changes: 88 additions & 0 deletions tests/test_gpu_examples.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import importlib
import os
import tempfile
import unittest
@@ -24,6 +25,7 @@
from accelerate.test_utils.testing import run_command
from accelerate.utils import patch_environment
from datasets import Audio, DatasetDict, load_dataset
from packaging import version
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
@@ -53,6 +55,7 @@
from peft.utils import SAFETENSORS_WEIGHTS_NAME

from .testing_utils import (
require_aqlm,
require_auto_awq,
require_auto_gptq,
require_bitsandbytes,
@@ -1383,6 +1386,91 @@ def test_model_loaded_in_float16_working(self):
trainer.train()


@require_torch_gpu
@require_aqlm
@unittest.skipUnless(
version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
"test requires `transformers>=4.38.0`",
)
class PeftAqlmGPUTests(unittest.TestCase):
r"""
AQLM + peft tests
"""

def setUp(self):
self.causal_lm_model_id = "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf"
Member

This model is stored in a pickle file; for tests we should really move to safetensors. Would it be possible for you to convert it or switch to a safetensors model for testing? Also, we should move models used for testing over to https://huggingface.co/peft-internal-testing, which I can do once we have a safetensors model.

Contributor Author

I've converted the model to safetensors. The tests still pass (with this PR's transformers) and the results are consistent.

self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

def tearDown(self):
r"""
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
torch.cuda.empty_cache()

def _check_inference_finite(self, model, batch):
# try inference without Trainer class
training = model.training
model.eval()
output = model(**batch.to(model.device))
assert torch.isfinite(output.logits).all()
model.train(training)

@pytest.mark.single_gpu_tests
def test_causal_lm_training_aqlm(self):
r"""
Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
correctly.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
Member

When running the test locally, I get the following error:

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_aqlm(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
>           model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map="cuda",
                torch_dtype="auto",
            )

tests/test_gpu_examples.py:1421: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:567: in from_pretrained
    return model_class.from_pretrained(
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/modeling_utils.py:3563: in from_pretrained
    hf_quantizer.postprocess_model(model)
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/quantizers/base.py:179: in postprocess_model
    return self._process_model_after_weight_loading(model, **kwargs)
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/quantizers/quantizer_aqlm.py:80: in _process_model_after_weight_loading
    model._is_quantized_training_enabled = False
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0...()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
name = '_is_quantized_training_enabled', value = False

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
                                "(torch.nn.Parameter or None expected)"
                                )
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
                for hook in _global_module_registration_hooks.values():
                    output = hook(self, name, value)
                    if output is not None:
                        value = output
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError(f"cannot assign '{torch.typename(value)}' as child module '{name}' "
                                    "(torch.nn.Module or None expected)"
                                    )
                for hook in _global_module_registration_hooks.values():
                    output = hook(self, name, value)
                    if output is not None:
                        value = output
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' "
                                        "(torch.Tensor or None expected)"
                                        )
                    for hook in _global_buffer_registration_hooks.values():
                        output = hook(self, name, value)
                        if output is not None:
                            value = output
                    buffers[name] = value
                else:
>                   super().__setattr__(name, value)
E                   AttributeError: can't set attribute '_is_quantized_training_enabled'

../../../anaconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1747: AttributeError

Not sure if that's the one that would be fixed by the transformers PR or if it's a different issue.

Contributor (@younesbelkada, Feb 21, 2024)

For that you indeed need to check out that transformers PR; maybe we can do a version check of transformers on the PEFT side, what do you think? @BenjaminBossan @BlackSamorez

Member

If we know what version this will be contained in, this would be a possibility. It would mean that we don't have a test at all until it's released though.

Contributor

Yes! It should be included in 4.38.0.

Contributor Author (@BlackSamorez, Feb 21, 2024)

@BenjaminBossan @BlackSamorez that's not the error I usually get when using main branch transformers. That would be `ValueError: The model you are trying to fine-tune is quantized with aqlm but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers to request the support for training support for aqlm`, which is consistent with that PR's logic: it adds the possibility of returning a positive `is_trainable` when aqlm's version is right.
Your transformers main is out of date and didn't catch this PR.

Contributor

@BenjaminBossan note that in our daily CI we build transformers from main, so IMO once the transformers PR is merged we can merge this PR! 🙏

Contributor Author (@BlackSamorez, Feb 21, 2024)

Looks like it has been merged, meaning that transformers main should fully support this PR's tests (at least that's the case on my machine).

Member

Okay, so this test should run successfully when we test against transformers main. Still, let's add logic to skip the test if the transformers version is too old to ensure that CI is green even when testing against the transformers release version.

Contributor Author

@BenjaminBossan added

@unittest.skipUnless(
    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
    "test requires `transformers>=4.38.0`",
)

self.causal_lm_model_id,
device_map="cuda",
torch_dtype="auto",
)

model = prepare_model_for_kbit_training(model)
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
logging_steps=1,
output_dir=tmp_dir,
fp16=True,
),
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()

model.cpu().save_pretrained(tmp_dir)

assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None


@require_torch_gpu
@require_auto_awq
class PeftAwqGPUTests(unittest.TestCase):
9 changes: 8 additions & 1 deletion tests/testing_utils.py
@@ -18,7 +18,7 @@
import pytest
import torch

from peft.import_utils import is_auto_awq_available, is_auto_gptq_available, is_optimum_available
from peft.import_utils import is_aqlm_available, is_auto_awq_available, is_auto_gptq_available, is_optimum_available


def require_torch_gpu(test_case):
@@ -61,6 +61,13 @@ def require_auto_gptq(test_case):
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


def require_aqlm(test_case):
"""
Decorator marking a test that requires aqlm. These tests are skipped when aqlm isn't installed.
"""
return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


def require_auto_awq(test_case):
"""
Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed.