
AQLM support for LoRA #1476

Merged (14 commits) on Feb 22, 2024
Conversation

BlackSamorez (Contributor):

This PR aims to add AQLM support for LoRA finetuning.

AQLM has recently been integrated into transformers with this PR, and it would make sense to add fine-tuning support for it.

On the aqlm inference side, proper autograd integration needs to be implemented. The work is being done in this branch: the basic code already works, but the efficient kernels do not work yet.
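To illustrate what "proper autograd integration" means here, a minimal sketch of the idea (not the actual aqlm implementation; `dequantize_weight` is a hypothetical helper): the forward pass uses the quantized weight, and the backward pass only needs the gradient with respect to the input, since the quantized weight itself stays frozen.

```python
import torch

class QuantizedLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, codes, codebooks, scales):
        # Dequantize (or run a fused quantized kernel) for the forward pass.
        weight = dequantize_weight(codes, codebooks, scales)  # hypothetical helper
        ctx.save_for_backward(weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        # Only the gradient w.r.t. the input is needed; the quantized weight is frozen,
        # so the LoRA adapters upstream still receive gradients.
        (weight,) = ctx.saved_tensors
        return grad_output @ weight, None, None, None

# Usage: y = QuantizedLinearFunction.apply(x, codes, codebooks, scales)
```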

BlackSamorez (Contributor Author):

Proof of concept finetuning Mixtral on Colab:
https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing

BlackSamorez (Contributor Author):

@younesbelkada

younesbelkada (Contributor) left a comment:

Thanks a lot for your great work @BlackSamorez, as always!
Can you merge your changes with the latest changes from main?
I am not 100% familiar with AQLM yet, so I also have an open question: do you think merging the adapter weights into the AQLM base model is doable?
For reference, this is how merging is performed in classic LoRA: https://huggingface.co/docs/peft/v0.7.1/en/conceptual_guides/lora#merge-lora-weights-into-the-base-model, which is also supported in QLoRA.

BlackSamorez (Contributor Author):

I believe merging would not be possible, because AQLM forces a very strong and specific symmetry on the weights (in the case of a single codebook, a limited set of repeating local weight patterns). An arbitrary LoRA adapter would not satisfy that symmetry.
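In rough terms (a sketch of the argument, not something spelled out in the PR): AQLM stores each group $g$ of 8-16 weights as a sum of codebook entries,

$$
w_g = \sum_{m=1}^{M} C_m[\,b_{g,m}\,],
$$

so the full weight matrix lives in a small discrete set. A merged weight $W_{\text{merged}} = W_{\text{AQLM}} + \tfrac{\alpha}{r} B A$ is dense and continuous-valued and would in general not lie in that set, so it could only be "merged" by re-quantizing the model.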

BenjaminBossan (Member):

Thanks a lot for adding support for AQLM. I'm not sure what the state of the PR is, whether it's ready for review or still in progress. Let me know if you want a full review.

At first glance, here are some things that are still missing:

  1. The import of aqlm should be guarded (possibly with a minimum version check) so that users don't get an error when it's not installed (see the sketch after this list)
  2. We should add documentation (could be a later PR, but that's not ideal)
  3. We should add tests (could be a later PR, but that's not ideal)
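A minimal sketch of such a guard, mirroring the `is_aqlm_available()` helper that appears later in the diff (the exact PEFT implementation may differ; the minimum version is an assumption based on this PR):

```python
import importlib.metadata
import importlib.util

from packaging import version

AQLM_MIN_VERSION = "1.0.2"  # assumed minimum; matches the version required in this PR

def is_aqlm_available() -> bool:
    """Return True if aqlm is installed and new enough to support LoRA training."""
    if importlib.util.find_spec("aqlm") is None:
        return False
    return version.parse(importlib.metadata.version("aqlm")) >= version.parse(AQLM_MIN_VERSION)

# Only touch aqlm when it is actually available, so users without it never hit an ImportError.
if is_aqlm_available():
    from aqlm import QuantizedLinear
```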

BlackSamorez (Contributor Author):

I'm working on each of those points (looking at #1399 as a reference).

younesbelkada (Contributor):

Thanks so much @BlackSamorez !

BlackSamorez (Contributor Author):

@BenjaminBossan
What would the correct docs page for this be?

BenjaminBossan (Member):

> What would the correct docs page for this be?

We have a section dedicated to quantization.

pacman100 (Contributor) left a comment:

Thank you @BlackSamorez for all the work wrt AQLM support for LoRA!

Went over the PR and left a comment. Overall looks good!


```python
if is_aqlm_available() and isinstance(target_base_layer, AqlmQuantizedLinear):
    new_module = QuantLinear(target, adapter_name, **kwargs)
    target.qweight = target_base_layer.codes
```
Contributor:

is this used?

BlackSamorez (Contributor Author):

Yes, this is the place where quantized linear layers get wrapped with a LoRA wrapper.
`qweight` itself is there simply to get its device here.
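For context, a simplified sketch of what such a LoRA wrapper around a frozen quantized linear layer computes (not the exact PEFT class; names and constructor arguments here are illustrative):

```python
import torch
import torch.nn as nn

class LoraWrappedQuantLinear(nn.Module):
    def __init__(self, base_layer: nn.Module, in_features: int, out_features: int, r: int, alpha: int):
        super().__init__()
        self.base_layer = base_layer          # frozen AQLM QuantizedLinear
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)    # the initial LoRA update is zero
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantized forward pass plus the trainable low-rank update.
        result = self.base_layer(x)
        return result + self.lora_B(self.lora_A(x)) * self.scaling
```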

BlackSamorez (Contributor Author):

I think I've addressed all of the issues above. We'll have to wait for the aqlm 1.0.2 release (corresponding PR), though, because it adds proper autograd support.
I'll tag you once it's released.

HuggingFaceDocBuilderDev (bot):

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan (Member) left a comment:

Thanks for adding this method so quickly, it looks really promising.

I have a couple of suggestions, but those should be easy to adjust. Please take a look.

```diff
@@ -46,6 +46,10 @@ RUN source activate peft && \
 RUN source activate peft && \
     python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl
+
+# Add aqlm for quantization testing
+RUN source activate peft && \
+    pip install aqlm[gpu]==1.0.2
```
Member:

Could you please move this install into the installation block below (lines 60-67) to avoid creating another cache step? Also, do you think it's a good idea to pin the version like that? It means that if there is a new aqlm release that breaks something in PEFT, we wouldn't notice it.

BlackSamorez (Contributor Author):

Moved, and replaced `==` with `>=`.

docs/source/developer_guides/quantization.md (review thread resolved)

Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for large language models. It quantizes multiple weights together, taking advantage of interdependencies between them: AQLM represents groups of 8-16 weights as a sum of multiple vector codes. This allows it to compress models down to as little as 2 bits per weight with relatively small accuracy loss.

Since the AQLM quantization process is computationally expensive, using prequantized models is recommended. A partial list of available models can be found in the official aqlm [repository](https://github.com/Vahe1994/AQLM).
Member:

It would be nice (and better for adoption) to have safetensors for all of these models.

BlackSamorez (Contributor Author):

I mostly used safetensors for models where we needed a low RAM footprint for demos. We're currently updating the models themselves as well, and we'll definitely standardize the checkpoints once we're done.

docs/source/developer_guides/quantization.md (review thread resolved)

```python
quantized_model = get_peft_model(quantized_model, peft_config)
```

You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning.
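For a fuller picture, a minimal end-to-end sketch of the workflow the docs describe, assuming aqlm and a recent transformers are installed (the model id is the TinyLlama AQLM checkpoint used in the tests; the LoRA hyperparameters and target modules are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf"

# Load a prequantized AQLM model; quantization at load time is not needed.
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype="auto",
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

quantized_model = get_peft_model(quantized_model, peft_config)
quantized_model.print_trainable_parameters()  # only the LoRA adapters are trainable
```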
Member:

How about adding this notebook to the examples/ folder in PEFT?

BlackSamorez (Contributor Author):

I think this notebook will suffice as an example in the docs, but it's not good enough to put on GitHub. It'll probably be replaced in a few weeks anyway, once we have better models, simpler PyPI installs, and a generally better example.

```python
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
```
Member:

We should be able to just do `self.base_layer = base_layer` here. Backwards compatibility is not an issue here, since, unlike for GPTQ, this is a new class.

BlackSamorez (Contributor Author):

The base layer is initialized during `LoraLayer.__init__(self, base_layer)`. I removed `self.quant_linear_module` because it is, indeed, not needed.



```python
if is_aqlm_available():
    from aqlm import QuantizedLinear as AqlmQuantizedLinear
```
Member:

Why do we need the alias? There is no name conflict because the PEFT class is called QuantLinear.

BlackSamorez (Contributor Author):

You're right, we don't. I simply didn't like those two names being similar.
I decided to rename `QuantLinear` -> `AqlmLoraLinear` for better readability of the model structure.

docs/source/developer_guides/quantization.md (review thread resolved)
BlackSamorez (Contributor Author):

@BenjaminBossan
I've addressed the issues above. Also, aqlm==1.0.2 has been released and I made sure that the demo is functional.
The tests I've added here, however, are blocked by another PR into transformers.

younesbelkada (Contributor) left a comment:

Very clean! Thanks very much! LGTM once @BenjaminBossan approves!

BenjaminBossan (Member) left a comment:

Thanks a lot for addressing the issues. From my point of view, the PR is almost ready to go, but we need to take care of a few issues with the test. Please take a look at my comments.

> The tests I've added here, however, are blocked by another PR: huggingface/transformers#29142.

If this means that the test is currently expected to fail, we need to take care of that. We don't want to have a failing CI until the next transformers release with the fix is published. Therefore, we need some kind of check on the test to ensure that it is skipped if the transformers version does not contain the fix.

"""

def setUp(self):
self.causal_lm_model_id = "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf"
Member:

This model is stored in a pickle file; for tests we should really move to safetensors. Would it be possible for you to convert it or switch to a safetensors model for testing? Also, we should move models used for testing over to https://huggingface.co/peft-internal-testing, which I can do once we have a safetensors model.

BlackSamorez (Contributor Author):

I've converted the model to safetensors. The tests still pass (with this PR's transformers) and the results are consistent.
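For reference, a minimal sketch of one way such a conversion can be done with the standard transformers API (the output directory is illustrative; this is not necessarily how it was done here):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf"

# Load the existing (pickle-based) checkpoint...
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# ...and re-save it with safetensors serialization, then upload the result to the Hub.
model.save_pretrained("TinyLlama-AQLM-safetensors", safe_serialization=True)
tokenizer.save_pretrained("TinyLlama-AQLM-safetensors")
```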

```python
        correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
```
Member:

When running the test locally, I get the following error:

```
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_aqlm(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
>           model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map="cuda",
                torch_dtype="auto",
            )

tests/test_gpu_examples.py:1421: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:567: in from_pretrained
    return model_class.from_pretrained(
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/modeling_utils.py:3563: in from_pretrained
    hf_quantizer.postprocess_model(model)
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/quantizers/base.py:179: in postprocess_model
    return self._process_model_after_weight_loading(model, **kwargs)
../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/quantizers/quantizer_aqlm.py:80: in _process_model_after_weight_loading
    model._is_quantized_training_enabled = False
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0...()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
name = '_is_quantized_training_enabled', value = False

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
                                "(torch.nn.Parameter or None expected)"
                                )
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
                for hook in _global_module_registration_hooks.values():
                    output = hook(self, name, value)
                    if output is not None:
                        value = output
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError(f"cannot assign '{torch.typename(value)}' as child module '{name}' "
                                    "(torch.nn.Module or None expected)"
                                    )
                for hook in _global_module_registration_hooks.values():
                    output = hook(self, name, value)
                    if output is not None:
                        value = output
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' "
                                        "(torch.Tensor or None expected)"
                                        )
                    for hook in _global_buffer_registration_hooks.values():
                        output = hook(self, name, value)
                        if output is not None:
                            value = output
                    buffers[name] = value
                else:
>                   super().__setattr__(name, value)
E                   AttributeError: can't set attribute '_is_quantized_training_enabled'

../../../anaconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1747: AttributeError
```

Not sure if that's the one that would be fixed by the transformers PR or if it's a different issue.

younesbelkada (Contributor), Feb 21, 2024:

For that you indeed need to check out that transformers PR. Maybe we can do a version check of transformers on the PEFT side, what do you think? @BenjaminBossan @BlackSamorez

Member:

If we know what version this will be contained in, this would be a possibility. It would mean that we don't have a test at all until it's released, though.

Contributor:

Yes! It should be included in 4.38.0.

BlackSamorez (Contributor Author), Feb 21, 2024:

@BenjaminBossan @BlackSamorez that's not the error I usually get when using main-branch transformers. That would be `ValueError: The model you are trying to fine-tune is quantized with aqlm but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers to request the support for training support for aqlm`, which is consistent with that PR's logic, which adds the possibility of returning a positive `is_trainable` when aqlm's version is right.
Your transformers main is out of date and didn't catch this PR.

Contributor:

@BenjaminBossan note that in our daily CI we build transformers from main, so IMO once the transformers PR is merged we can merge this PR! 🙏

BlackSamorez (Contributor Author), Feb 21, 2024:

Looks like it has been merged, meaning that transformers main should fully support this PR's tests (at least that's the case on my machine).

Member:

Okay, so this test should run successfully when we test against transformers main. Still, let's add logic to skip the test if the transformers version is too old, so that CI stays green even when testing against the transformers release version.

BlackSamorez (Contributor Author):

@BenjaminBossan added:

```python
# Imports needed by this decorator in the test module (added here for completeness).
import importlib.metadata
import unittest

from packaging import version

@unittest.skipUnless(
    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
    "test requires `transformers>=4.38.0`",
)
```

BenjaminBossan (Member) left a comment:

Thanks so much for addressing the last concern, this LGTM now.

We should not forget to move a copy of the model to our internal testing repo, but that can be done in a follow-up PR.

I'll leave the merging to @younesbelkada in case he wants to double check the last few changes.

younesbelkada (Contributor) left a comment:

Great work! Thanks so much for your great work @BlackSamorez!

younesbelkada merged commit 23213ca into huggingface:main on Feb 22, 2024. 14 checks passed.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
* aqlm

* Style and copied tests

* aqlm import guard

* docs

* correct model in tests

* Update docs/source/developer_guides/quantization.md

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update docs/source/developer_guides/quantization.md

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* moved aqlm install and added >=

* Removed `quant_linear_module`

* AqlmLoraLinear

* docs update

* transformers version check

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>