ENH: Refactor LoRA bnb layers for faster initialization #994

Merged

Conversation

@BenjaminBossan commented Oct 5, 2023

Partly addresses #896

Description

After speeding up normal LoRA layer initialization, this PR improves initialization speed of bnb LoRA layers.

The method used to achieve this differs from the one used before: this time, the base layer is stored as a reference on the LoRA layer. This allows us to avoid calling __init__ on the bnb layer, which is the slow part.
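
To illustrate the idea, here is a minimal sketch (not the actual PEFT implementation; the class and attribute names are purely illustrative):

import torch.nn as nn

class LoraWrapper(nn.Module):  # illustrative name, not the real PEFT class
    def __init__(self, base_layer: nn.Module, r: int = 8):
        super().__init__()
        # keep a reference to the already-created (quantized) layer instead of
        # constructing a new one, which would call the slow bnb __init__
        self.base_layer = base_layer
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False)

    def forward(self, x, *args, **kwargs):
        # delegate to the stored reference, passing through *args and **kwargs
        result = self.base_layer(x, *args, **kwargs)
        return result + self.lora_B(self.lora_A(x))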

Notes

We cannot use the same method as for the normal LoRA layers (i.e. calling the super class's __init__ with the meta device), because the bnb layers have extra logic that still creates unnecessary weights (this is something that could potentially be fixed by bnb).

However, the method used here could also be applied to the normal layers, so if we want consistency, they could be refactored to use the same approach.

Another small advantage of this approach is that we can now correctly handle any *args, **kwargs that are passed to forward. If we had this consistently in all layers, it would allow us to handle extra arguments like lora_scale in the future.

Interestingly, even though we now save the base layer as a reference, which results in a different state_dict, the existing models can still be loaded successfully. This is because the adapter state_dict is not affected by the change, so users can still load their existing adapters.

The only problem would occur if users dump the whole model, i.e. base model and adapter, using torch.save and then try to load it with torch.load. For those users, we could theoretically provide a script to convert the state_dict (i.e. rename some keys).
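
Such a conversion script could look roughly like this (purely hypothetical sketch: only keys of layers that are actually wrapped by LoRA would need the extra "base_layer" segment, and the exact mapping depends on the final implementation):

import torch

old_sd = torch.load("full_model.bin", map_location="cpu")
wrapped_modules = ["query_key_value"]  # hypothetical list of wrapped module names
new_sd = {}
for key, value in old_sd.items():
    module, _, param = key.rpartition(".")
    if any(module.endswith(name) for name in wrapped_modules) and "lora_" not in key:
        # e.g. "...query_key_value.weight" -> "...query_key_value.base_layer.weight"
        key = f"{module}.base_layer.{param}"
    new_sd[key] = value
torch.save(new_sd, "full_model_converted.bin")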

To ensure that the old adapters can still be loaded successfully, I'm working at the same time on adding regression tests. I'll create a separate PR for those to avoid blowing up this one.

Tests

I ran a test on bloomz-1b1 measuring how long it takes to create the PeftModel; the results (before → after) are:

8bit: 1108.34 ms → 26.82 ms
4bit: 1101.96 ms → 23.69 ms

Script was:

import time
from contextlib import contextmanager

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

@contextmanager
def timed(msg):
    start = time.perf_counter()
    yield
    end = time.perf_counter()
    print(f"{msg} took {1000 * (end - start):.2f} ms")

torch.manual_seed(1000)
model_id = "bigscience/bloomz-1b1"
test_8bit = True  # set to False to benchmark the 4bit path instead

if test_8bit:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=True,
    )
    config = LoraConfig(
        r=8,
        init_lora_weights=False,
    )
else:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.float32,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        torch_dtype=torch.float32,
    )
    config = LoraConfig(
        r=8,
        init_lora_weights=False,
    )

with timed("creating peft model"):
    get_peft_model(model, config)

@HuggingFaceDocBuilderDev commented Oct 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@BenjaminBossan BenjaminBossan marked this pull request as ready for review October 5, 2023 14:07
@pacman100 left a comment

Thank you @BenjaminBossan for speeding up the creation of quantized peft models 🚀. I like the approach; it simplifies the code and makes it more extensible. Left a couple of comments.

@pacman100 left a comment

Thank you @BenjaminBossan, good to merge after resolving the conflicts.

@BenjaminBossan BenjaminBossan merged commit e98df91 into huggingface:main Oct 10, 2023
@BenjaminBossan BenjaminBossan deleted the refactor-bnb-faster-init branch October 10, 2023 14:47
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Nov 1, 2023
This is a POC to show how we could achieve mixing adapter types such as
LoRA and LoKr.

Description

The very general idea is that we can already mix multiple adapters of
the same type, e.g. add two LoRA adapters, but right now we fail when
trying to mix different types. This restriction has been lifted by
adding a new class PeftMixedModel which deals with different adapter
types.

The usage looks something like this:

    base_model = ...
    config0 = LoraConfig(...)
    # set mixed=True
    peft_model = get_peft_model(base_model, config0, mixed=True)
    config1 = LoHaConfig(...)
    peft_model.add_adapter(config1, "other")
    peft_model.set_adapter(["default", "other"])

At this point, both adapters are active at the same time.

Existing code should not be affected by this change, since users need to
opt into this behavior by setting mixed=True.

Also interesting is that this method can be used for a single adapter
type but with very different configs. Right now, we have limited support
for that (e.g. for LoRA, different r values by using rank_pattern), but
with this, we don't need to special case the differing arguments anymore.
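
For instance (hypothetical usage following the API shown above; the target
module names are just examples):

    base_model = ...
    config0 = LoraConfig(r=8, target_modules=["query_key_value"])
    peft_model = get_peft_model(base_model, config0, mixed=True)
    # a second LoRA adapter with a completely different rank and targets
    config1 = LoraConfig(r=64, target_modules=["dense"])
    peft_model.add_adapter(config1, "other")
    peft_model.set_adapter(["default", "other"])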

Implementation

Apart from adding the new PeftMixedModel class to replace PeftModel, I
added a new class LycorisModel which replaces LoraModel, LoHaModel etc.
This class checks the config type and then uses the corresponding
LoraModel, LoHaModel etc. to create the adapter.

Another crucial change I had to make was to adopt the "base layer
pattern". This is the pattern that was, for instance, used to speed up
initialization in LoRA bnb layers in PR huggingface#994.

The main change is that the adapter layer wraps the original layer and
calls forward on that layer, instead of doing stuff like this:

    F.linear(
        input, transpose(self.weight, self.fan_in_fan_out)
    )

which completely circumvents the call to the target layer's forward
method. With the base layer pattern, we now call the target layer's
forward method. Therefore, if the target layer is another adapter layer,
we call its forward method correctly.
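
Roughly, the forward call now looks like this instead (simplified sketch, not
the exact code):

    def forward(self, x, *args, **kwargs):
        # delegate to the wrapped layer, so that nested adapter layers also
        # get their own forward called
        result = self.base_layer(x, *args, **kwargs)
        result = result + self.lora_B(self.lora_A(x)) * self.scaling
        return result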

This change has the nice side benefit that we no longer need to use
_init_empty_weight -- in fact, we don't initialize any of the target
layer's weights anymore, since we have a reference to it.

Note that same as for the bnb layers, this should not be backwards
incompatible, since the adapter weights and their state_dicts are not
affected by this change.

Somewhat unrelated changes

During debugging, I got very annoyed with the fact that the reprs of
adapter layers and normal PyTorch layers are hard to distinguish, e.g.
the type is just "Linear". Now, for adapter layers, it is prefixed by
the adapter type, e.g. "lora.Linear".
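
A __repr__ override on the adapter layer classes is enough to achieve this
(sketch of the idea):

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep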

TODOs

- [ ] For now, I only added this capability for LoRA and LoHa as a POC.
  It needs to be added to LoKr and AdaLora too.
- [ ] The unit tests are very rudimentary right now, only a simple model
  is tested in two settings.
- [ ] There is no documentation so far.
- [ ] I'm not yet sure if the same logic can be applied to IA³ or if it
  may fail because IA³ can apply its scaling to the input, not the output.
- [ ] It is currently not possible to represent a mixed adapter model as
  a single config. I think we can come up with a solution but I don't
  think it is necessary for a first version of this.
@BenjaminBossan mentioned this pull request Nov 1, 2023
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Nov 10, 2023
This PR supersedes huggingface#995. The description below is copied and modified
from that PR. For some technical reasons, it was easier for me to create
a new PR than to update the previous one, sorry for that.

Description

In general, for regression tests, we need two steps:

1. Creating the regression artifacts, in this case the adapter
   checkpoint and the expected output of the model.
2. Running the regression tests, i.e. loading the adapter and checking
   that the output of the model is the same as the expected output.

My approach is to re-use as much code as possible between those two
steps. Therefore, the same test script can be used for both, with only
an environment variable to distinguish between the two. Step 1 is
invoked by calling:

`REGRESSION_CREATION_MODE=True pytest tests/regression/test_regression.py`

and to run the second step, we call:

`pytest tests/regression/test_regression.py`
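
Inside the test code, the switch between the two modes could look roughly like
this (illustrative sketch; the helper functions are made up):

    import os

    def strtobool(val: str) -> bool:
        return val.lower() in ("1", "true", "yes")

    CREATION_MODE = strtobool(os.environ.get("REGRESSION_CREATION_MODE", "false"))

    def run_test_case(model, output, artifact_dir):
        if CREATION_MODE:
            # step 1: store the adapter checkpoint and the expected model output
            save_regression_artifacts(model, output, artifact_dir)  # hypothetical helper
        else:
            # step 2: load the stored artifacts and compare to the current output
            check_regression_artifacts(model, artifact_dir)  # hypothetical helper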

Creating regression artifacts

The first step will create an adapter checkpoint and an output for the
given PEFT version and test setting in a new directory. E.g. it will
create a directory `tests/regression/lora_opt-125m_bnb_4bit/0.5.0/` that
contains adapter_model.bin and output.pt.

Before this step runs, there is a check that the git repo is clean (no
dirty worktree) and that the commit is tagged (i.e. corresponds to a
release version of PEFT). Otherwise, we may accidentally create
regression artifacts that do not correspond to any PEFT release.
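
Such a check could be implemented roughly as follows (a sketch, not necessarily
the code used in the PR):

    import subprocess

    def assert_clean_and_tagged():
        # the worktree must be clean, i.e. no uncommitted changes
        status = subprocess.run(
            ["git", "status", "--porcelain"], capture_output=True, text=True
        )
        if status.stdout.strip():
            raise RuntimeError("git worktree is dirty, commit or stash your changes")
        # HEAD must point at a tagged commit, i.e. a PEFT release
        describe = subprocess.run(["git", "describe", "--exact-match", "--tags"])
        if describe.returncode != 0:
            raise RuntimeError("current commit is not tagged, check out a release tag")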

The easiest way to get such a clean state (say, for PEFT v0.5.0) is by
checking out a tagged commit, e.g:

`git checkout v0.5.0`

before running the first step.

The first step will also skip the creation of regression artifacts if
they already exist.

It is possible to circumvent all the aforementioned checks by setting
the environment variable `REGRESSION_FORCE_MODE` to True like so:

`REGRESSION_FORCE_MODE=True REGRESSION_CREATION_MODE=True pytest tests/regression/test_regression.py`

You should only do this if you know exactly what you're doing.

Running regression tests

The second step is much simpler. It will load the adapters and the
output created in the first step, and compare the output to the output
from a new PEFT model using the loaded adapter. The outputs should be
the same.
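
Per test setting, the check looks roughly like this (a sketch assuming a
transformers causal LM; the function name is illustrative):

    import torch
    from peft import PeftModel

    def check_against_stored_output(base_model, inputs, artifact_dir):
        # load the adapter created with the old PEFT version into a new model
        peft_model = PeftModel.from_pretrained(base_model, artifact_dir)
        expected = torch.load(f"{artifact_dir}/output.pt")
        with torch.inference_mode():
            output = peft_model(**inputs).logits
        # the output should match the one recorded with the old PEFT version
        torch.testing.assert_close(output, expected)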

If more than one version is discovered for a given test setting, all of
them are tested.

Notes

As is, the adapters are stored in the git repo itself. Since they're
relatively small, the total size of the repo is still reasonable.
However, it could be better to store those adapters on HF Hub instead.
This would, however, make things a bit more complicated (not sure how to
parse directories etc. on Hub).

The regression tests included in this PR were used to check that
huggingface#994 still allows loading checkpoints created with PEFT v0.6.1.