
Add 4-bit support to IA3 - Outperforms QLoRA in both speed and memory consumption #864

Merged (13 commits) on Sep 26, 2023

Conversation

@His-Wardship (Contributor) commented Aug 25, 2023

Further to #843, I have attempted to adjust the ia3.py source code to facilitate 4-bit IA3 training. The suggested adjustments to ia3.py passed initial testing and ran inference properly. I have not yet attempted to merge the adapter into the model at less than full precision, and expect this is likely beyond my month-old capabilities.

As I have cautioned in previous posts, I have literally no experience in computer programming or software design, having only begun toying around with AI (and, by extension, Python) about a month ago. I wrote this PR purely by reference to the HuggingFace and PyTorch documentation, as well as to the existing implementation of 4-bit training for AdaLoRA and LoRA. As such, this code may be messy or not in compliance with basic standards (as I am not aware of what these standards are!).

I have run some preliminary benchmarks and have generally found it to outperform 4-bit QLoRA in both training speed and memory consumption. This is less pronounced during extremely lightweight training (low rank, low sequence length, batch size of one), as the majority of VRAM consumption is attributable to simply loading the model; the divergence becomes more apparent at more strenuous hyperparameters. In #843, I provided a brief summary table comparing 8-bit IA3 to LoRA and QLoRA. I have re-run all benchmarks (as changes have been made to my environment) and added new rows for 4-bit IA3.

With Batch Size: 2, Gradient Accumulation: 2, Sequence Length: 4096

| Training Details | Training Time | VRAM Consumption |
| --- | --- | --- |
| 4-bit IA3 | 42 minutes | 15.90GB / 18.90GB** |
| 8-bit IA3 | 47 minutes | 23.15GB |
| 4-bit QLoRA - Rank 32, Alpha 128 | 58 minutes | 19.53GB |
| 4-bit QLoRA - Rank 128, Alpha 256 | 1 hour 3 minutes | 21.48GB |
| 8-bit LoRA - Rank 32, Alpha 128 | N/A | OOM |
| 8-bit LoRA - Rank 128, Alpha 256 | N/A | OOM |

With Batch Size: 1, Gradient Accumulation: 2, Sequence Length: 2048

| Training Details | Training Time | VRAM Consumption |
| --- | --- | --- |
| 4-bit IA3 | 1 hour 4 minutes | 10.13GB |
| 8-bit IA3 | 54 minutes | 16.29GB |
| 4-bit QLoRA - Rank 32, Alpha 128 | 1 hour 10 minutes | 10.17GB |
| 4-bit QLoRA - Rank 128, Alpha 256 | 1 hour 11 minutes | 11.91GB |
| 8-bit LoRA - Rank 32, Alpha 128 | 1 hour 0 minutes | 16.68GB |
| 8-bit LoRA - Rank 128, Alpha 256 | 1 hour 1 minute | 18.03GB |

While preliminary testing shows that this code works, I am aware that these changes are not particularly complex and only took 45 minutes to make (especially considering my lack of experience), and that 4-bit support was left out of the original release of IA3, so I assume there is likely a more expansive reason for its exclusion (e.g. incompatibility with other parts of the PEFT repo). Please do let me know if so, as I do not wish to clutter up this repo with an incompatible or ill-advised mess!

Notably, the memory savings against higher-rank LoRA hyperparameters make it more practicable to train ~30B-parameter models on 24GB consumer hardware without resorting to overly small sequence lengths or techniques such as NVMe offloading, which require semi-recent hardware and come at a significant speed cost.

It is entirely possible (and perhaps even likely) that there are significant inefficiencies in my implementation of this code (again, this was largely an exercise in educated copy-pasting from source docs) so please do feel free to replace anything that looks dubious.
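For context, here is the rough shape of the training setup I have been using to exercise the new code path (the model name, quantization settings and the omitted training loop are illustrative only, not part of this PR):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import IA3Config, TaskType, get_peft_model, prepare_model_for_kbit_training

    # Load the base model with bitsandbytes nf4 quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-30b",  # illustrative model id
        quantization_config=bnb_config,
        device_map="auto",
    )
    model = prepare_model_for_kbit_training(model)

    # Attach IA3 adapters; with this PR, the 4-bit base layers are wrapped by the new Linear4bit
    peft_config = IA3Config(task_type=TaskType.CAUSAL_LM)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    # ...then train with a standard Trainer or a custom loop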

** In the Batch Size: 2, Gradient Accumulation: 2, Sequence Length: 4096 test, but not the Batch Size: 1, Gradient Accumulation: 2, Sequence Length: 2048 tests, VRAM consumption for 4-bit IA3 jumps by exactly 3,000MB roughly 10-20 steps into training. I tested this multiple times using different parameters, and regardless of the actual VRAM values, the jump is always exactly 3,000MB. It does not happen under other circumstances and I have not been able to determine the cause. It may be related to an error in Axolotl / my adjustments to Axolotl to facilitate IA3, as no such memory cost was observed when I re-ran the 4-bit IA3 tests on a plain, barebones Python script I had written separately. I would appreciate any guidance on this.
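For anyone trying to reproduce the jump, a generic way to sample VRAM at a given step is something like the following (illustrative; not necessarily how the figures in the tables above were collected):

    import torch

    # Current and peak CUDA memory held by tensors, in GiB
    allocated = torch.cuda.memory_allocated() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f"{allocated:.2f} GiB allocated, {peak:.2f} GiB peak")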

@BenjaminBossan (Member)

This looks very promising and useful, thanks for the PR. I haven't done an in-depth review yet, but at a glance this looks good.

Would you be up for adding a test to test_common_gpu.py so that we have a test for this new feature? We could assist you with that if you need help.

@BenjaminBossan (Member)

Hi, as you may have seen, there is now a merge conflict in your PR after #851 was merged. It should, however, be straightforward to resolve:

The new Linear4bit layer you wrote should now be added to this file: https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/bnb.py

And IA3Model has moved to the following file, so please apply your changes there: https://github.com/huggingface/peft/blob/main/src/peft/tuners/ia3/model.py

If you run into trouble while resolving the merge conflict, just let us know.

@His-Wardship (Contributor, Author)

Hi, as you may have seen, there is now a merge conflict in your PR after #851 was merged. It should, however, be straightforward to resolve:
...

Thank you for the heads-up; I made these changes yesterday and they seem to work. I've updated the PR accordingly, so let me know if I didn't do this correctly, as I'm very much a beginner with GitHub, though it is an intuitive platform. I'm still playing around with benchmarking in my spare time to properly set out the comparative benefits of 4-bit IA3. Most notably, I have been able to achieve strong results fine-tuning the Llama 33B model loaded from 16-bit weights into nf4, requiring only 19.5GB of VRAM. I'm just in the process of testing the limits of a single 24GB card (without any tricks like running a headless server or running the display off CPU graphics, which may not be practicable for some/many people).

The principal issue I have encountered thus far is that inference is extremely slow when the IA3 adapter is loaded via PeftModelForCausalLM onto the full-weight model loaded in nf4. This was expected as similar behaviour is described for (Q)LoRA. However, merging the IA3 adapter with a quantised model is currently not supported and merging the weights in 16-bit requires ~75GB CPU RAM. While my PC can handle this due to being outrageously over-spec'd for running Cookie Clicker, I do not believe most end-users would have CPU RAM in the 64-128GB range.

I note that code was recently pushed to main to support merging (Q)LoRA at 8- and 4-bit, so I may later try to duplicate this for IA3, along with the quantlinear layer. There may be some mathematical reasons why this is impossible, but my empirical knowledge of the calculations that go on in the background is far too pallid to comprehend this!

Would you be up for adding a test to test_common_gpu.py so that we have a test for this new feature? We could assist you with that if you need help.

Happy to try my hand at this.

@BenjaminBossan (Member)

I note that code was recently pushed to main to support merging (Q)LoRA at 8- and 4-bit, so I may later try to duplicate this for IA3, along with the quantlinear layer. There may be some mathematical reasons why this is impossible, but my empirical knowledge of the calculations that go on in the background is far too pallid to comprehend this!

Yes, that's certainly worth exploring. At first glance, I think it should be possible, but with quantization, there can always be pitfalls. If you do find a solution, let's put it in a separate PR though.

Happy to try my hand at this.

Thanks a lot. Don't hesitate to ask if something is unclear.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@pacman100 (Contributor) left a comment

Thank you @His-Wardship for adding support for 4-bit IA3 and the detailed summary of speedup and memory usage 🚀! Left a comment.

Please run make style and make quality to fix the code quality issues.

Review comment on src/peft/tuners/ia3/model.py (resolved)
@His-Wardship (Contributor, Author)

Apologies for leaving this unattended; I've been preoccupied with work. I have some spare time and I'll get the above two points done tomorrow.

@His-Wardship (Contributor, Author) commented Sep 14, 2023

@BenjaminBossan

Apologies for the (long) delay in addressing the above comments. I have been preoccupied with work, but given these took little time to do, I should have done them far earlier - the main roadblock was reading the documentation for pytest, as I had never used it before (nor knew what it was).

I have encountered a small issue in writing an IA3 addition to test_common_gpu.py: after the refactoring of the tuners, to my (inexperienced) understanding it does not seem possible to add a consolidated IA3 test without adjusting the existing tests in the script. This is because the current script uses the "from {module}" import format rather than the "import {module}" format. I understand that the "from {module}" format does not introduce the module name from which the imports are drawn into the namespace. As such, I believe this creates a conflict between Linear4bit and Linear8bitLt, which share identical names between tuners.ia3 and tuners.lora.

I thought of two approaches to resolve this, but wanted to get your views on them prior to making any changes to the rest of the testing script.

  1. Add conditional blocks to the testing script to load and unload the relevant modules as required. I dislike this approach as it does not appear to be used elsewhere, and also adds an unnecessary layer of nesting to the tests.
  2. Rewrite the test to use the "import {module}" format, which would include the parent modules in the namespace. However, this would require prepending the existing module references with "tuners.ia3.*" and "tuners.lora.*", which is inconsistent with the approach taken by the rest of the repo.

Please do let me know if there is an alternate, more efficient approach (or indeed, if the problem I have identified isn't even a problem at all, and there is a simple solution pre-built in Python).

@BenjaminBossan (Member)

Hi, thanks for working further on this feature and reading up on how to perform testing.

Regarding the issue you encountered about clashing names: How about importing the layers with an alias, along the lines of from peft.tuners.ia3 import Linear4bit as IA3Linear4bit. It's also possible to import a submodule, e.g. from peft.tuners import ia3; ia3.Linear4bit ....
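For instance, a quick sketch of what the aliased imports could look like (assuming bitsandbytes is installed, so these layer classes are exported):

    # Aliasing avoids the clash between identically named layer classes
    from peft.tuners.ia3 import Linear4bit as IA3Linear4bit
    from peft.tuners.ia3 import Linear8bitLt as IA3Linear8bitLt
    from peft.tuners.lora import Linear4bit as LoraLinear4bit
    from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt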

In general, I wouldn't worry too much about the import styles in the tests; it's fine to use unusual patterns there if necessary. As you correctly observed, sticking to the style of the code base is desirable, but the tests can deviate.

@His-Wardship reopened this Sep 19, 2023
@His-Wardship (Contributor, Author)

Hi, thanks for working further on this feature and reading up on how to perform testing.

Regarding the issue you encountered about clashing names: How about importing the layers with an alias, along the lines of from peft.tuners.ia3 import Linear4bit as IA3Linear4bit. It's also possible to import a submodule, e.g. from peft.tuners import ia3; ia3.Linear4bit ....

In general, I wouldn't worry too much about the import styles in the tests; it's fine to use unusual patterns there if necessary. As you correctly observed, sticking to the style of the code base is desirable, but the tests can deviate.

Thank you for your advice; I've added tests to test_common_gpu.py for 8-bit and 4-bit IA3. I have added aliases for both the IA3 layers and the LoRA layers, as it is possible that additional adapter methods may be added in the future and it would be helpful (I think) to identify which is which.

Please ignore the spam of commits and closures/re-openings, I was having a bit of difficulty using the GitHub CLI after creating an editable local install of PEFT. In the end, I just manually edited the forked source files online the stone-age way.

@BenjaminBossan (Member) left a comment

Hi @His-Wardship this looks very solid and it's basically good to merge. Also great that you added the 8bit test.

I just noticed a mistake I made earlier which is tangentially related to your PR. It would be fantastic if you could fix it, see my comment.

Review comment on src/peft/tuners/ia3/bnb.py (outdated, resolved)
Add guard to check for BNB
@His-Wardship (Contributor, Author) commented Sep 20, 2023

Thanks @BenjaminBossan, I've made those changes (I think!).

The only remaining point is that I looked at creating a loading test for IA3; however, I noted that the other loading tests (e.g. LoRA) use premade sample adapters which are hosted by an admin account on HF. Naturally, I can't upload there! I am unsure if HF has standard content for test adapters, but if it's just a random load of wikitext for the purpose of making an adapter, I'm happy to add some code and provide a dummy adapter to be uploaded to whichever admin account hosts the test adapters.

Also, you may have seen the mention above from Winglian that they're apparently ready to add support for IA3 to Axolotl as soon as this PR is merged. Hopefully this improves collective awareness of IA3, especially as it is considerably more practicable to fine-tune ~34B models using 4-bit IA3 on consumer 24GB GPUs, whereas it's quite awkward with 4-bit QLoRA (painfully slow, and likely requires running a headless terminal to save a few MB of VRAM, etc.).

@BenjaminBossan (Member)

Only remaining point is that I looked at creating a loading test for IA3, however I noted that the other loading tests (e.g. LoRA) use premade sample adapters which are hosted by an admin account on HF.

Which test are you referring to exactly? I think we can proceed without that test and add it in a separate PR, where we can upload testing artifacts to the HF account if needed.

@His-Wardship (Contributor, Author) commented Sep 21, 2023

Which test are you referring to exactly? I think we can proceed without that test and add it in a separate PR, where we can upload testing artifacts to the HF account if needed.

I'm referring to test_common_gpu/_lora_bnb_4bit_quantization_from_pretrained_safetensors (test copied below for ease of reference). It's a very straightforward test that simply loads a model and a predefined PEFT adapter and tests that it is capable of running inference. Happy to put in another PR at a later point on this for IA3, if at all helpful. The same code can be duplicated, it just needs pointing to a separate peft_model_id.

    def test_lora_bnb_4bit_quantization_from_pretrained_safetensors(self):
        r"""
        Test that tests if the 4bit quantization using LoRA works as expected with safetensors weights.
        """
        model_id = "facebook/opt-350m"
        peft_model_id = "ybelkada/test-st-lora"

        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
        model = PeftModel.from_pretrained(model, peft_model_id)

        _ = model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
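An IA3 counterpart would be near-identical; roughly the following sketch, where the peft_model_id is hypothetical since no such test adapter has been uploaded yet:

    def test_ia3_bnb_4bit_quantization_from_pretrained_safetensors(self):
        r"""
        Test that the 4bit quantization using IA3 works as expected with safetensors weights.
        """
        model_id = "facebook/opt-350m"
        peft_model_id = "peft-internal-testing/test-st-ia3"  # hypothetical test adapter

        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
        model = PeftModel.from_pretrained(model, peft_model_id)

        _ = model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))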

Time permitting, the next item I would want to address in a separate PR (or if anyone else addresses it in the interim) is merging IA3 to base model weights in 4-bit. I'm currently merging weights by loading everything in bf16 and then requantizing afterwards using a separate set of scripts I've written. This works fine, but it is slow and requires all of my PC's resources for a ~34B model (24GB VRAM and 128GB system RAM), and I doubt these resource costs (particularly the 128GB system RAM part) are feasible for many people on consumer devices. The limiting factor for me will be finding the time to read all of the documentation for quantization / the main parts of the BNB code - as I've mentioned before, I have no prior knowledge in this area, so I have to read the docs before I can do anything! (I'm still only partway through the PyTorch documentation; it makes some of my regulatory handbooks look light by comparison!)
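For reference, the full-precision merge I'm doing today looks roughly like the following (paths are illustrative; the requantization scripts are separate and not shown):

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Load the base model in bf16 - this is the memory-hungry step for a ~34B model
    base = AutoModelForCausalLM.from_pretrained(
        "path/to/base-model", torch_dtype=torch.bfloat16
    )
    model = PeftModel.from_pretrained(base, "path/to/ia3-adapter")

    # Fold the IA3 scaling vectors into the base weights, then save and requantize separately
    merged = model.merge_and_unload()
    merged.save_pretrained("path/to/merged-model")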

@BenjaminBossan (Member)

I'm referring to test_common_gpu/_lora_bnb_4bit_quantization_from_pretrained_safetensors (test copied below for ease of reference).

Yes, in that case I think we can add it later.

Time permitting, the next item I would want to address in a separate PR (or if anyone else addresses it in the interim) is merging IA3 to base model weights in 4-bit.

This would indeed be useful to have. Speaking of which, this PR should also add an error when trying to merge in 4-bit, like this one, right?

    if getattr(self.model, "is_loaded_in_8bit", False):
        raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")

Add error for merging in 4-bit
@His-Wardship (Contributor, Author)

This would indeed be useful to have. Speaking of which, this PR should also add an error when trying to merge in 4-bit, like this one, right?

    if getattr(self.model, "is_loaded_in_8bit", False):
        raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")

It probably should - I've added a warning saying as much.
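For reference, the new check mirrors the existing 8-bit guard, roughly as follows (the exact wording in the commit may differ):

    if getattr(self.model, "is_loaded_in_4bit", False):
        raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")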

@BenjaminBossan (Member) left a comment

This looks good, thanks a lot for adding 4bit to IA³.

Ideally, we can have a second pair of eyes by @younesbelkada or @pacman100

@His-Wardship (Contributor, Author)

@BenjaminBossan

Thank you for the approval - I've received a notification that conflicts exist due to a recent update to the main branch. There's a button beside this notification which simply reads "Resolve conflicts" - do I just click this or do I need to rewrite the PR code somewhere?

Having read the changes to the IA3 related files, it seems as if there are just extra lines which do not impinge upon the PR, so does GitHub just adjust for these automatically upon clicking "Resolve conflicts"?

Thank you!

@BenjaminBossan (Member)

@His-Wardship There was a pretty big PR merged recently, #873, which I think is causing the merge conflicts. Basically, you need to port the changes of that PR over to this PR.

If you click "Resolve conflicts", GitHub will bring you to an online text editor, which shows the current code and the changed code on top of each other. You can use that editor to solve the conflict. If you don't want to do that in the editor, you can just cancel. So it's safe to click that button and take a look.

Alternatively, you could fetch the latest state of the main branch from upstream with git and merge it into your main branch. This still requires you to resolve the merge conflict, but you can do it locally on your editor. Choose whatever method you prefer.

To give an example, one of the conflicts is:

<<<<<<< main
        loaded_in_8bit = optionnal_kwargs["loaded_in_8bit"]
        loaded_in_4bit = optionnal_kwargs["loaded_in_4bit"]
        current_key = optionnal_kwargs["current_key"]
=======
        loaded_in_8bit = optional_kwargs["loaded_in_8bit"]
        current_key = optional_kwargs["current_key"]
>>>>>>> main

It says "main" twice, because you worked on the "main" branch of your fork (next time, it's better to create a new branch with a descriptive name). As you can see, the top code is yours, below is the current code. Git could not automatically resolve the conflict, since both pieces have changes: You added loaded_in_4bit, and our main fixed the typo in optional_kwargs. So the code you need to create must have the loaded_in_4bit line + the typo fix. Same logic goes for the other conflicting code. Does that make sense?

If you have any questions about resolving merge conflicts, let us know. There are also a ton of resources online.

Resolve conflicts with PR #873
Resolve conflicts with PR #873
@His-Wardship (Contributor, Author)

@BenjaminBossan Thank you for your detailed explanation, and tips on proper organisation for future PRs. I've made the changes that were marked as conflicts - I actually thought the "optionnal" typo was intentional, given how widely it was used, and assumed it was a pun on "nn" being used as an abbreviation for neural network (as in, the cuDNN library).

Hopefully that's everything resolved!

@younesbelkada (Contributor) left a comment

Thanks a lot for your great contribution @His-Wardship - the changes look great to me. Before merging, can you run the styling checks?

make style && make quality

Remove accidental blank indent for make quality test
@His-Wardship (Contributor, Author)

Thanks a lot for your great contribution @His-Wardship - the changes look great to me. Before merging, can you run the styling checks?

make style && make quality

Thank you @younesbelkada. I had accidentally hit tab on a blank line before committing, so it failed the make style/quality test. I've fixed that now and re-run the tests locally, which passed.

@His-Wardship (Contributor, Author)

@BenjaminBossan Just a heads-up, I've resolved the conflicts introduced by #905. I had a lot of pain with requires_grad while building fine-tuning scripts for IA3 and had to fiddle a lot with the loss calculation or force certain settings to make it run; happy to know I wasn't the only one! I've re-run make style && make quality, so hopefully everything is good to go before another conflict arises.

@BenjaminBossan (Member) left a comment

Thanks a lot, almost perfect from my point of view. Only two small issues, please take a look.

Two review comments on src/peft/tuners/ia3/bnb.py (outdated, resolved)
Re-arrange to remove unnecessary parenthesis.
@His-Wardship (Contributor, Author)

Thanks a lot, almost perfect from my point of view. Only two small issues, please take a look.

Will do - that's strange, make quality and make style on my local install reported nothing to fix: "All done! ✨ 🍰 ✨ 104 files left unchanged." I've fixed these points; sorry for the back-and-forth, I need to fully read the documentation on using local editable installs properly. make style actually deletes the empty final line if I re-insert it (for the avoidance of doubt, I re-inserted it both with and without indentation, and it was removed in both cases).

@BenjaminBossan (Member)

Maybe the issue you encountered has to do with your editor or, if you use them, pre-commit hooks. E.g. some editors automatically add an empty line at the end, so you may end up with two empty lines or something. It's a bit of an annoying issue, because it's hard to spot.

@younesbelkada (Contributor) left a comment

Amazing work, thanks a lot @His-Wardship for your hard work!

@pacman100 (Contributor) left a comment

Thank you @His-Wardship for adding support for bnb 4-bit quantization for IA3, very impactful. Also, thank you for all the discussions and comparative analysis wrt LoRA, insightful! 🤗🚀✨

@BenjaminBossan merged commit 08b6665 into huggingface:main on Sep 26, 2023
11 checks passed
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request Sep 26, 2023
Notes:

- Add guard to IA³ Linear8bitLt definition (should have already been there).
- Merging not supported (yet).
@scissorstail

To anyone who encountered the RuntimeError: expected scalar type Float but found Half error with 4-bit...

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    # BASE_MODEL, LORA_WEIGHTS and device_map are defined elsewhere
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        torch_dtype=torch.float32,  # Use torch.float32 HERE!!!
        device_map=device_map,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
    )

@zengwufu commented Jan 23, 2024

I also encountered the error: RuntimeError: expected scalar type Float but found BFloat16.
In the training stage, I use this config:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer
    from peft import IA3Config, TaskType, get_peft_model

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = AutoModelForCausalLM.from_pretrained(
        "llama2-7b",
        quantization_config=bnb_config,
        use_cache=False,
    )
    peft_config = IA3Config(task_type=TaskType.CAUSAL_LM)
    peft_model = get_peft_model(model=model, peft_config=peft_config)
    ...
    # use the trainer to train the peft_model
    trainer = Trainer(model=peft_model, ...)
    trainer.train()
    # save the fine-tuned model using the trainer
    trainer.save_model("final_model")

The training process is normal.
In the evaluation stage, I use this config:

    import torch
    from peft import AutoPeftModelForCausalLM

    # load the model from the saved dir
    model = AutoPeftModelForCausalLM.from_pretrained(
        "final_model",
        torch_dtype=torch.bfloat16,
        load_in_4bit=True
    )
    ...
    # generate
    model.generate(...)

If I change torch_dtype=torch.bfloat16 to torch_dtype=torch.float32 in the evaluation config, the program no longer reports errors, but the generation results are very confusing.
