From 25f015d98bbf8abea42734349bf9aed12d7b48ad Mon Sep 17 00:00:00 2001
From: "Yixiao Li, Macbook Air"
Date: Sun, 19 Nov 2023 13:13:56 -0500
Subject: [PATCH 01/10] Add LoftQ method integrated into LoRA. Add example code for LoftQ usage.

---
 README.md                                  |   1 +
 examples/loftq_finetuning/README.md        |  68 ++
 .../loftq_finetuning/quantize_save_load.py | 198 ++++
 .../loftq_finetuning/train_gsm8k_llama.py  | 871 ++++++++++++++++++
 src/peft/__init__.py                       |   1 +
 src/peft/tuners/__init__.py                |   2 +-
 src/peft/tuners/lora/__init__.py           |   4 +-
 src/peft/tuners/lora/config.py             |  57 +-
 src/peft/tuners/lora/layer.py              |  44 +-
 src/peft/tuners/lora/model.py              |   4 +
 src/peft/utils/loftq_utils.py              | 200 ++++
 11 files changed, 1438 insertions(+), 12 deletions(-)
 create mode 100644 examples/loftq_finetuning/README.md
 create mode 100644 examples/loftq_finetuning/quantize_save_load.py
 create mode 100644 examples/loftq_finetuning/train_gsm8k_llama.py
 create mode 100644 src/peft/utils/loftq_utils.py

diff --git a/README.md b/README.md
index d4dfee5c38..0db5a1d406 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ Supported methods:
 7. MultiTask Prompt Tuning: [Multitask Prompt Tuning Enables Parameter-Efficient Transfer Learning](https://arxiv.org/abs/2303.02861)
 8. LoHa: [FedPara: Low-Rank Hadamard Product for Communication-Efficient Federated Learning](https://arxiv.org/abs/2108.06098)
 9. LoKr: [KronA: Parameter Efficient Tuning with Kronecker Adapter](https://arxiv.org/abs/2212.10650) based on [Navigating Text-To-Image Customization:From LyCORIS Fine-Tuning to Model Evaluation](https://arxiv.org/abs/2309.14859) implementation
+10. LoftQ: [LoftQ: LoRA-Fine-Tuning-aware Quantization for Large Language Models](https://arxiv.org/abs/2310.08659)
 
 ## Getting started
 
diff --git a/examples/loftq_finetuning/README.md b/examples/loftq_finetuning/README.md
new file mode 100644
index 0000000000..02a005a6a5
--- /dev/null
+++ b/examples/loftq_finetuning/README.md
@@ -0,0 +1,68 @@
+# LoftQ: LoRA-fine-tuning-aware Quantization
+
+## Introduction
+
+LoftQ provides better initialization for the LoRA adapters A and B,
+and for the quantization of the pre-trained weights W.
+
+## Quantization
+We recommend saving the quantized backbone model as fp16/fp32
+and loading it as [NormalFloat4](https://arxiv.org/abs/2305.14314).
+
+We provide a simple example to show how to quantize the llama-2-7b model and save/load it.
+
+```sh
+python quantize_save_load.py \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --token HF_TOKEN \
+    --bits 4 --iter 5 --rank 16 \
+    --save_dir model_zoo/loftq/
+```
+
+- `HF_TOKEN` is the token used to access the [LLAMA models](https://huggingface.co/meta-llama).
+- The `quantize_and_save()` function quantizes the backbone and initializes the LoRA adapters.
+It creates two folders under `$save_dir`: the quantized backbone is at `Llama-2-7b-hf-4bit-16rank-backbone`,
+and the LoRA adapters are at `Llama-2-7b-hf-4bit-16rank-adapters`.
+
+## Fine-tuning
+
+Here is an example of loading the quantized backbone and LoRA adapters:
+
+```python
+import os
+
+from transformers import AutoModelForCausalLM
+from peft import PeftModel
+
+
+base_model = AutoModelForCausalLM.from_pretrained(os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-backbone"),
+                                                  load_in_4bit=True,
+                                                  )
+peft_model = PeftModel.from_pretrained(base_model,
+                                       os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-adapters"),
+                                       is_trainable=True,
+                                       )
+```
+
+We also provide an example to fine-tune LoftQ on GSM8K.
+We load the quantized backbone and LoRA adapters from the [LoftQ Huggingface hub](https://huggingface.co/LoftQ).
+
+```sh
+python train_gsm8k_llama.py \
+    --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank-backbone \
+    --adapter_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank-adapters \
+    --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \
+    --learning_rate 3e-4 \
+    --seed 202 \
+    --dataset_name gsm8k \
+    --dataset_config main \
+    --pad_to_max_length \
+    --max_source_length 128 \
+    --max_target_length 256 \
+    --num_train_epochs 5 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 4 \
+    --with_tracking \
+    --report_to tensorboard
+```
diff --git a/examples/loftq_finetuning/quantize_save_load.py b/examples/loftq_finetuning/quantize_save_load.py
new file mode 100644
index 0000000000..3966dea607
--- /dev/null
+++ b/examples/loftq_finetuning/quantize_save_load.py
@@ -0,0 +1,198 @@
+import argparse
+import os
+
+import torch
+import torch.nn as nn
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+
+from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model
+
+
+class Shell(nn.Module):
+    def __init__(self, weight, bias=None):
+        super().__init__()
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        if bias is not None:
+            self.bias = nn.Parameter(bias, requires_grad=False)
+
+
+def unwarap_model(model, sub_module_name=".base_layer"):
+    sub_module_name_list = [k.split(sub_module_name)[0] for k in model.state_dict().keys() if sub_module_name in k]
+    sub_module_name_set = set(sub_module_name_list)
+    for name in sub_module_name_set:
+        # get the parent of the submodule
+        name_parent = ".".join(name.split(".")[:-1])
+        name_child = name.split(".")[-1]
+        sub_module = model.get_submodule(name_parent)
+        print(sub_module)
+
+        # replace with shell
+        child = getattr(sub_module, name_child)
+        weight = getattr(child.base_layer, "weight", None)
+        bias = getattr(child.base_layer, "bias", None)
+        shell = Shell(weight, bias)
+
+        setattr(sub_module, name_child, shell)
+
+    print("You have unwrapped the model. 
Use it on your own risk.") + + +def print_model(model, name): + print("=" * 10 + name + "=" * 10) + print(model) + for name, param in model.named_parameters(): + if torch.is_tensor(param): + if param.dtype in [torch.float32, torch.float16]: + print( + name, + param.shape, + param.device, + param.dtype, + param.requires_grad, + param.mean().item(), + param.max().item(), + ) + else: + print(name, param.shape, param.device, param.dtype, param.requires_grad) + + +def arg_parse(): + parser = argparse.ArgumentParser(description="Quantize a model with LoftQ.") + parser.add_argument( + "--model_name_or_path", + type=str, + default=None, + required=True, + help="The name or path of the fp32/16 model.", + ) + parser.add_argument( + "--token", + type=str, + default=None, + help="The access token to download model from HuggingFace Hub.", + ) + parser.add_argument( + "--bits", + type=int, + default=4, + help="The quantized bits", + ) + parser.add_argument( + "--iter", + type=int, + default=1, + help="The alternating steps in LoftQ", + ) + parser.add_argument( + "--rank", + type=int, + default=16, + help="The rank of the LoRA adapter", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./model_zoo/loftq/", + help="The rank of the LoRA adapter", + ) + args = parser.parse_args() + return args + + +def quantize_and_save(): + args = arg_parse() + + # Download weights and configure LoRA + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token) + if "llama" in args.model_name_or_path.lower(): + model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto") + task_type = TaskType.CAUSAL_LM + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"] + + elif "bart" in args.model_name_or_path.lower(): + model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto") + task_type = TaskType.SEQ_2_SEQ_LM + target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"] + + elif "deberta" in args.model_name_or_path.lower(): + model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token) + model = model.cuda() + task_type = TaskType.SEQ_CLS + target_modules = ["query_proj", "key_proj", "value_proj", "dense"] # embeddings not supported by peft + else: + raise NotImplementedError("Other models not supported yet.") + + # Config of LoftQ + loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter, loftq_fake=True) + + lora_config = LoraConfig( + task_type=task_type, + inference_mode=True, + r=args.rank, + lora_alpha=args.rank, + lora_dropout=0.1, + target_modules=target_modules, + init_lora_weights="loftq", + loftq_config=loftq_config, + ) + + # Obtain LoftQ model + lora_model = get_peft_model(model, lora_config) + base_model = lora_model.get_base_model() + + # Save LoftQ model + model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank" + base_model_dir = os.path.join(args.save_dir, model_name + "-backbone") + lora_model_dir = os.path.join(args.save_dir, model_name + "-adapters") + + # save lora adapters first + lora_model.base_model.peft_config[ + "default" + ].base_model_name_or_path = base_model_dir # This can be a local path or Hub model id + lora_model.base_model.peft_config["default"].init_lora_weights = True # Don't apply LoftQ when loading again + + lora_model.save_pretrained(lora_model_dir) + print_model(lora_model, "lora_model") + + # remove lora 
adapters and save the backbone + unwarap_model(base_model) + base_model.save_pretrained(base_model_dir) + tokenizer.save_pretrained(base_model_dir) + + print_model(base_model, "base_model") + + return base_model_dir, lora_model_dir + + +def load_loftq(base_model_path, lora_adapter_path): + if "llama" in base_model_path.lower(): + model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map="auto", load_in_4bit=True) + elif "bart" in base_model_path.lower(): + model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path, device_map="auto", load_in_4bit=True) + elif "deberta" in base_model_path.lower(): + model = AutoModelForSequenceClassification.from_pretrained(base_model_path, load_in_4bit=True) + else: + raise NotImplementedError("Other models not supported yet.") + + lora_model = PeftModel.from_pretrained(model, lora_adapter_path, is_trainable=True) + + # Do training or inference below + print_model(lora_model, "lora_model") + print_model(model, "base_model") + + +if __name__ == "__main__": + base_dir, lora_dir = quantize_and_save() + load_loftq(base_dir, lora_dir) + +# example command: +# python quantize_save_load.py \ +# --model_name_or_path meta-llama/Llama-2-7b-hf \ +# --token XXX \ +# --bits 4 --iter 5 --rank 16 \ +# --save_dir ./model_zoo/loftq/ diff --git a/examples/loftq_finetuning/train_gsm8k_llama.py b/examples/loftq_finetuning/train_gsm8k_llama.py new file mode 100644 index 0000000000..d0ca714d45 --- /dev/null +++ b/examples/loftq_finetuning/train_gsm8k_llama.py @@ -0,0 +1,871 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import argparse +import copy +import logging +import math +import os +import random +import re +from pathlib import Path + +import datasets +import torch +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import Repository, create_repo +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import send_example_telemetry +from transformers.utils.versions import require_version + +from peft import PeftModel + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
+# check_min_version("4.32.0.dev0") + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) +HF_TOKEN = "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" + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv, txt or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv, txt or a json file containing the validation data." + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--ignore_pad_token_for_loss", + type=bool, + default=True, + help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.", + ) + parser.add_argument( + "--max_source_length", + type=int, + default=128, + help=( + "The maximum total input sequence length after " + "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded." + ), + ) + parser.add_argument( + "--max_target_length", + type=int, + default=128, + help=( + "The maximum total sequence length for target text after " + "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." + "during ``evaluate`` and ``predict``." + ), + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--trust_remote_code", + type=bool, + default=False, + help=( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" + "should only be set to `True` for repositories you trust and in which you have read the code, as it will" + "execute code present on the Hub on your local machine." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--low_cpu_mem_usage", + action="store_true", + help=( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "If passed, LLM loading time and RAM consumption will be benefited." 
+ ), + ) + ########################## + # Generation Config # + ########################## + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="temperature of 1.0 has no effect, lower tend toward greedy sampling", + ) + parser.add_argument("--k", type=int, default=40, help="Choose k candidate words") + parser.add_argument("--p", type=float, default=0.95, help="The sum of probability of candidate words is 0.9 ") + + ########################## + # Exp Args # + ########################## + parser.add_argument( + "--adapter_name_or_path", + type=str, + default=None, + help=( + "The LoRA adapter checkpoint. Set None if you want to fine-tune from LoftQ." + "Specify a path if you want to evaluate." + ), + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + # Retrieve of infer repo_name + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + # Create repo and retrieve repo_id + repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id + # Clone repo locally + repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ if args.config_name: + config = AutoConfig.from_pretrained( + args.config_name, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained( + args.model_name_or_path, + trust_remote_code=args.trust_remote_code, + ) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code + ) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, + use_fast=not args.use_slow_tokenizer, + trust_remote_code=args.trust_remote_code, + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + ########################## + # Tokenizer # + ########################## + tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token + tokenizer.padding_side = "left" # Allow batched inference + tokenizer.truncation_side = "left" + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + low_cpu_mem_usage=True, + load_in_4bit=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config, trust_remote_code=args.trust_remote_code) + + ########################## + # Peft Model # + ########################## + if args.adapter_name_or_path is None: + args.adapter_name_or_path = args.model_name_or_path + model = PeftModel.from_pretrained(model, args.adapter_name_or_path, is_trainable=True) + model.print_trainable_parameters() + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + ########################## + # GSM8K dataset # + ########################## + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + + # Get the column names for source/target. + source_column, target_column = "question", "answer" + + # Temporarily set max_target_length for training. + padding = "max_length" if args.pad_to_max_length else False + task_prompt = "\nAnswer the above question. 
First think step by step and then answer the final number.\n" + + def prompt_process(sent_1, sent_2, prompt_1="", prompt_2="", prompt_3=""): + sent_2 = sent_2.replace("####", "The final answer is") + return prompt_1 + sent_1 + prompt_2 + sent_2 + prompt_3 + + def preprocess_function_train(examples): + sources = examples[source_column] + targets = examples[target_column] + + inputs = [prompt_process(source, target, prompt_2=task_prompt) for (source, target) in zip(sources, targets)] + + model_inputs = tokenizer( + inputs, + max_length=args.max_source_length + args.max_target_length, + padding=padding, + truncation=True, + return_tensors="pt", + ) + + labels = copy.deepcopy(model_inputs) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + if padding == "max_length" and args.ignore_pad_token_for_loss: + # get the length of the target tokens. -1 to kick out the token + target_tokens = tokenizer(targets, padding=False) + target_len = [len(label) - 1 for label in target_tokens["input_ids"]] + + # don't calculate the loss from source and padding (left padding) + for i in range(len(labels["input_ids"])): + labels["input_ids"][i, : -target_len[i]] = -100 + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + def preprocess_function_test(examples): + sources = examples[source_column] + labels = examples[target_column] + + inputs = [source + task_prompt for source in sources] + + model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True) + labels = tokenizer(labels, max_length=args.max_target_length, padding=padding, truncation=True) + + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + with accelerator.main_process_first(): + train_dataset = raw_datasets["train"].map( + preprocess_function_train, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on training dataset", + ) + + eval_dataset = raw_datasets["test"].map( + preprocess_function_test, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on test dataset", + ) + + # Log a few random samples from the set: + for index in random.sample(range(len(train_dataset)), 2): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + for index in random.sample(range(len(eval_dataset)), 2): + logger.info(f"Sample {index} of the validation set: {eval_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. 
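+    # Note: only parameters whose names contain "lora" receive weight decay here; bias and
+    # layer-norm parameters are placed in a separate group with weight decay disabled.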
+ no_decay = ["bias", "layer_norm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and "lora" in n], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties. + if accelerator.distributed_type == DistributedType.TPU: + model.tie_weights() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("clm_no_trainer", experiment_config) + + # Train! + total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + checkpoint_path = args.resume_from_checkpoint + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + checkpoint_path = path + path = os.path.basename(checkpoint_path) + + accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") + accelerator.load_state(path) + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps + starting_epoch = resume_step // len(train_dataloader) + resume_step -= starting_epoch * len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking: + total_loss = 0 + if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + accelerator.backward(loss) + accelerator.print(f"Epoch: {epoch} | Step: {step} | Loss: {loss}") + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + if completed_steps >= args.max_train_steps: + break + + model.eval() + gen_kwargs = { + "max_new_tokens": args.max_target_length, + "temperature": args.temperature, + "top_k": args.k, + "top_p": args.p, + "do_sample": True, + } + ans_pred_list = [] + ans_gold_list = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + gen_kwargs["input_ids"] = batch["input_ids"] + gen_kwargs["attention_mask"] = batch["attention_mask"] + generated_tokens = accelerator.unwrap_model(model).generate(**gen_kwargs) + + pred_tokens = generated_tokens[:, args.max_source_length :] + pred_tokens = accelerator.pad_across_processes(pred_tokens, dim=1, pad_index=tokenizer.pad_token_id) + gold_tokens = batch["labels"] + + if 
not args.pad_to_max_length: + # If we did not pad to max length, we need to pad the labels too + gold_tokens = accelerator.pad_across_processes( + batch["labels"], dim=1, pad_index=tokenizer.pad_token_id + ) + + pred_tokens, gold_tokens = accelerator.gather_for_metrics((pred_tokens, gold_tokens)) + pred_tokens, gold_tokens = pred_tokens.cpu().numpy(), gold_tokens.cpu().numpy() + + if isinstance(pred_tokens, tuple): + pred_tokens = pred_tokens[0] + decoded_pred = tokenizer.batch_decode(pred_tokens, skip_special_tokens=True) + decoded_gold = tokenizer.batch_decode(gold_tokens, skip_special_tokens=True) + + # Extract the numbers in sentences + accelerator.print(decoded_pred) + ans_pred_list += [extract_answer_number(sentence_pred) for sentence_pred in decoded_pred] + ans_gold_list += [extract_answer_number(sentence_gold) for sentence_gold in decoded_gold] + + accelerator.print(ans_pred_list) + accelerator.print(ans_gold_list) + accuracy = compute_accuracy(ans_gold_list, ans_pred_list) + + logger.info(f"epoch {epoch}: accuracy: {accuracy}") + + if args.with_tracking: + accelerator.log( + { + "accuracy": accuracy, + "train_loss": total_loss.item() / len(train_dataloader), + "epoch": epoch, + "step": completed_steps, + }, + step=completed_steps, + ) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + +def extract_answer_number(sentence: str) -> float: + sentence = sentence.replace(",", "") + pred = [s for s in re.findall(r"-?\d+\.?\d*", sentence)] + if not pred: + return float("inf") + segment = sentence.split("The final answer is ") + if len(segment) > 1: + pred_answer = segment[1] + pred_answer = [s for s in re.findall(r"-?\d+\.?\d*", pred_answer)] + if len(pred_answer) > 0: + pred_answer = pred_answer[0] + else: + pred_answer = float(pred[-1]) + else: + pred_answer = float(pred[-1]) + + if isinstance(pred_answer, str): + try: + pred_answer = float(pred_answer) + except ValueError: + pred_answer = float("inf") + return pred_answer + + +def compute_accuracy(pred: list, gold: list): + acc = 0.0 + for p, g in zip(pred, gold): + if p == g: + acc += 1 + + return acc / len(pred) + + +if __name__ == "__main__": + main() + +# example command + +# python train_gsm8k_llama.py \ +# --model_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-backbone \ +# --adapter_name_or_path LoftQ/Llama-2-7b-hf-bit4-rank64-adapters \ +# --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \ +# --learning_rate 1e-4 \ 
+# --seed 202 \ +# --dataset_name gsm8k \ +# --dataset_config main \ +# --pad_to_max_length \ +# --max_source_length 128 \ +# --max_target_length 256 \ +# --num_train_epochs 5 \ +# --per_device_train_batch_size 4 \ +# --per_device_eval_batch_size 4 \ +# --gradient_accumulation_steps 4 \ +# --with_tracking \ +# --report_to tensorboard diff --git a/src/peft/__init__.py b/src/peft/__init__.py index a3ce332f24..4d9380e697 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -48,6 +48,7 @@ AdaptionPromptConfig, AdaptionPromptModel, LoraConfig, + LoftQConfig, LoraModel, LoHaConfig, LoHaModel, diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index b357d47dc1..666e29d997 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -18,7 +18,7 @@ # limitations under the License. from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel -from .lora import LoraConfig, LoraModel +from .lora import LoraConfig, LoraModel, LoftQConfig from .loha import LoHaConfig, LoHaModel from .lokr import LoKrConfig, LoKrModel from .ia3 import IA3Config, IA3Model diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index d02bf2d948..ddc81d53cd 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -15,13 +15,13 @@ from peft.import_utils import is_bnb_4bit_available, is_bnb_available -from .config import LoraConfig +from .config import LoftQConfig, LoraConfig from .gptq import QuantLinear from .layer import Conv2d, Embedding, Linear, LoraLayer from .model import LoraModel -__all__ = ["LoraConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"] +__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"] if is_bnb_available(): diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 2412b61a1a..ab68ca0f96 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -20,6 +20,42 @@ from peft.utils import PeftType +@dataclass +class LoftQConfig: + """ + This is the sub-configuration class to store the configuration of a [`LoraModel`]. + + Args: + bits_pattern (`dict`): The mapping from layer names or regexp expression to bits which are different from the + default bits specified by `bits`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 2`}. + bits (`int`): Quantization bits for LoftQ. + iter (`int`): Alternating iterations for LoftQ. + fake (`bool`): True: use fp16/fp32; used for first time to save weights. False: use bitsandbytes 4bit linear + models. weights can't be saved. Recommend to set to True, save the weights and load the saved weights in 4 + bits. + """ + + bits_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to ranks which are different from " + "the default bits specified by `bits`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 2`}" + ) + }, + ) + loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"}) + loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) + loftq_fake: bool = field( + default=True, + metadata={ + "help": "True: use fp16/fp32; used for first time to save weights." + "False: use bitsandbytes 4bit linear models. weights can't be saved." + "Recommend to set to True, save the weights and load the saved weights in 4 bits." 
+ }, + ) + + @dataclass class LoraConfig(PeftConfig): """ @@ -76,12 +112,13 @@ class LoraConfig(PeftConfig): "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." }, ) - init_lora_weights: bool = field( + init_lora_weights: Union[bool, str] = field( default=True, metadata={ "help": ( "Whether to initialize the weights of the Lora layers with their default initialization. Don't change " "this setting, except if you know exactly what you're doing." + "Pass `'loftq'` to use LoftQ initialization" ), }, ) @@ -117,6 +154,16 @@ class LoraConfig(PeftConfig): ) }, ) + # dict type is used when loading config.json + loftq_config: Union[LoftQConfig, dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The configuration of LoftQ. If this is not None, then LoftQ will be used to quantize the backbone " + "weights and initialize Lora layers." + ) + }, + ) def __post_init__(self): self.peft_type = PeftType.LORA @@ -130,3 +177,11 @@ def __post_init__(self): # if target_modules is a regex expression, then layers_pattern should be None if isinstance(self.target_modules, str) and self.layers_pattern is not None: raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") + + # handle init_lora_weights and loftq_config + if self.init_lora_weights == "loftq" and self.loftq_config is None: + raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") + + # convert loftq_config to dict + if self.loftq_config is not None and not isinstance(self.loftq_config, dict): + self.loftq_config = vars(self.loftq_config) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index c263053183..1f9e2cc00a 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -15,7 +15,7 @@ import math import warnings -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch import torch.nn as nn @@ -23,6 +23,7 @@ from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils.loftq_utils import loftq_init from peft.utils.other import transpose @@ -46,6 +47,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: # Mark the weight as unmerged self._disable_adapters = False self.merged_adapters = [] + self.kwargs = kwargs base_layer = self.get_base_layer() if isinstance(base_layer, nn.Linear): @@ -83,7 +85,9 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False) self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: self.reset_lora_parameters(adapter_name) weight = getattr(self.get_base_layer(), "weight", None) @@ -115,7 +119,9 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: self.reset_lora_parameters(adapter_name) weight = getattr(base_layer, "weight", None) @@ -142,7 +148,9 @@ def update_layer_embedding(self, adapter_name, r, 
lora_alpha, lora_dropout, init self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A) self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B) self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: self.reset_lora_parameters(adapter_name) base_layer = self.get_base_layer() @@ -162,6 +170,26 @@ def reset_lora_parameters(self, adapter_name): nn.init.zeros_(self.lora_embedding_A[adapter_name]) nn.init.normal_(self.lora_embedding_B[adapter_name]) + def loftq_init(self, adapter_name): + weight = self.get_base_layer().weight + kwargs = { + "num_bits": self.kwargs.get("loftq_bits", 4), + "reduced_rank": self.r[adapter_name], + "num_iter": self.kwargs.get("loftq_iter", 1), + "quiet": False, + } + + qweight, lora_A, lora_B = loftq_init(weight, **kwargs) + if adapter_name in self.lora_A.keys(): + # initialize A the same way as the default for nn.Linear and B to zero + self.lora_A[adapter_name].weight.data = lora_A + self.lora_B[adapter_name].weight.data = lora_B + if adapter_name in self.lora_embedding_A.keys(): + # initialize a the same way as the default for nn.linear and b to zero + self.lora_embedding_A[adapter_name].weight.data = lora_A + self.lora_embedding_B[adapter_name].weight.data = lora_B + self.get_base_layer().weight.data = qweight + def set_scale(self, adapter, scale): if adapter not in self.scaling: # Ignore the case where the adapter is not in the layer @@ -210,11 +238,11 @@ def __init__( lora_dropout: float = 0.0, fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) is_target_conv_1d_layer: bool = False, - init_lora_weights: bool = True, + init_lora_weights: Union[bool, str] = True, **kwargs, ) -> None: super().__init__() - LoraLayer.__init__(self, base_layer) + LoraLayer.__init__(self, base_layer, **kwargs) self.fan_in_fan_out = fan_in_fan_out self._active_adapter = adapter_name @@ -343,7 +371,7 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - init_lora_weights: bool = True, + init_lora_weights: Union[bool, str] = True, **kwargs, ) -> None: super().__init__() @@ -483,7 +511,7 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, - init_lora_weights: bool = True, + init_lora_weights: Union[bool, str] = True, **kwargs, ) -> None: super().__init__() diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index a5b7735ce3..dac33f1618 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -285,8 +285,10 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): elif isinstance(target_base_layer, torch.nn.Embedding): embedding_kwargs = kwargs.copy() embedding_kwargs.pop("fan_in_fan_out", None) + embedding_kwargs.update(lora_config.loftq_config) new_module = Embedding(target, adapter_name, **embedding_kwargs) elif isinstance(target_base_layer, torch.nn.Conv2d): + kwargs.update(lora_config.loftq_config) new_module = Conv2d(target, adapter_name, **kwargs) elif isinstance(target_base_layer, torch.nn.Linear): if kwargs["fan_in_fan_out"]: @@ -295,6 +297,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "Setting fan_in_fan_out to False." 
) kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + kwargs.update(lora_config.loftq_config) new_module = Linear(target, adapter_name, **kwargs) elif isinstance(target_base_layer, Conv1D): if not kwargs["fan_in_fan_out"]: @@ -303,6 +306,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): "Setting fan_in_fan_out to True." ) kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + kwargs.update(lora_config.loftq_config) new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) else: raise ValueError( diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py new file mode 100644 index 0000000000..4002443dc7 --- /dev/null +++ b/src/peft/utils/loftq_utils.py @@ -0,0 +1,200 @@ +import torch +from scipy.stats import norm + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available + + +if is_bnb_available(): + import bitsandbytes as bnb + + +class NFQuantizer: + def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_bits = num_bits + self.device = device + self.method = method + self.block_size = block_size + if self.method == "normal": + self.norm_lookup_table = self.create_normal_map(num_bits=self.num_bits) + self.norm_lookup_table = self.norm_lookup_table.to(device) + elif self.method == "uniform": + self.norm_lookup_table = self.create_uniform_map(num_bits=self.num_bits) + self.norm_lookup_table = self.norm_lookup_table.to(device) + else: + raise NotImplementedError("Other quantization methods not supported yet.") + + @staticmethod + def create_uniform_map(symmetric=False, num_bits=4): + if symmetric: + # print("symmetric uniform quantization") + negative = torch.linspace(-1, 0, 2 ** (num_bits - 1)) + positive = torch.linspace(0, 1, 2 ** (num_bits - 1)) + table = torch.cat([negative, positive[1:]]) + else: + # print("asymmetric uniform quantization") + table = torch.linspace(-1, 1, 2**num_bits) + return table + + @staticmethod + def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2): + variations = 2**num_bits + + if symmetric: + # print("symmetric NormalFloat") + v = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist() + values = [] + for index in range(len(v) - 1): + values.append(0.5 * v[index] + 0.5 * v[index + 1]) + v = values + else: + # one more positive value, this is an asymmetric type + # print("asymmetric NormalFloat") + v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist() + # print(torch.linspace(offset, 0.5, 9)[:-1]) + # print(v1) + v2 = [0] + # v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist() + # print(torch.linspace(offset, 0.5, 8)[:-1]) + # print(v3) + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + # print(values) + return values + # assert values. 
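+
+    # Note: quantize_tensor/dequantize_tensor use a single absmax scale for the whole tensor,
+    # while quantize_block/dequantize_block quantize in groups of `block_size` entries with
+    # one scale per block.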
+ + def quantize_tensor(self, weight): + max_abs = torch.abs(weight).max() + weight_normed = weight / max_abs + + weight_normed_expanded = weight_normed.unsqueeze(-1) + + # Reshape L to have the same number of dimensions as X_expanded + L_reshaped = torch.tensor(self.norm_lookup_table).reshape(1, -1) + + # Calculate the absolute difference between X_expanded and L_reshaped + abs_diff = torch.abs(weight_normed_expanded - L_reshaped) + + # Find the index of the minimum absolute difference for each element + qweight = torch.argmin(abs_diff, dim=-1) + # print(min_index) + return qweight, max_abs + + def dequantize_tensor(self, qweight, max_abs): + qweight_flatten = qweight.flatten() + + weight_normed = self.norm_lookup_table[qweight_flatten] + weight = weight_normed * max_abs + + weight = weight.reshape(qweight.shape) + + return weight + + def quantize_block(self, weight): + assert len(weight.shape) == 2 and weight.shape[0] * weight.shape[1] % self.block_size == 0 + M, N = weight.shape + device = weight.device + + # Quantization + weight_flatten = weight.flatten() # (M*N, ) + weight_block = weight_flatten.reshape(-1, self.block_size) # (L, B), L = M * N / B + if self.method == "normal": + weight_max = weight_block.abs().max(dim=-1)[0] # (L, 1) + elif self.method == "uniform": + weight_max = weight_block.mean(dim=-1) + 2.5 * weight_block.std(dim=-1) + else: + raise NotImplementedError("Method not supported yet.") + weight_max = weight_max.unsqueeze(-1) + weight_divabs = weight_block / weight_max # (L, B) + weight_divabs = weight_divabs.unsqueeze(-1) # (L, B, 1) + L_reshaped = self.norm_lookup_table.reshape(1, -1) # (1, 2**K) + + abs_diff = torch.abs(weight_divabs - L_reshaped) # (L, B, 2**K) + qweight = torch.argmin(abs_diff, dim=-1) # (L, B) + + # Pack multiple k-bit into uint8 + qweight = qweight.reshape(-1, 8 // self.num_bits) + qweight_pack = torch.zeros((M * N // 8 * self.num_bits, 1), dtype=torch.uint8, device=device) + + # data format example: + # [1, 0, 3, 2] or [01, 00, 11, 10] -> [10110001], LIFO + for i in range(8 // self.num_bits): + qweight[:, i] = qweight[:, i] << i * self.num_bits + qweight_pack[:, 0] |= qweight[:, i] + + return qweight_pack, weight_max, weight.shape + + def dequantize_block(self, qweight, weight_max, weight_shape): + # unpack weight + device = qweight.device + weight = torch.zeros((qweight.shape[0], 8 // self.num_bits), dtype=torch.float32, device=device) + for i in range(8 // self.num_bits): + lookup_table_idx = qweight.to(torch.long) % 2**self.num_bits # get the most right 2 bits + lookup_table_idx = lookup_table_idx.to(torch.int) + weight[:, i] = self.norm_lookup_table[lookup_table_idx].squeeze() + qweight = qweight >> self.num_bits # right shift 2 bits of the original data + + weight_block = weight.reshape(-1, self.block_size) + weight = weight_block * weight_max + weight = weight.reshape(weight_shape) + + return weight + + +def _low_rank_decomposition(weight, reduced_rank=32): + """ + :param weight: The matrix to decompose, of shape (H, W) :param reduced_rank: the final rank :return: + """ + matrix_dimension = len(weight.size()) + assert matrix_dimension == 2, "Only Support 2D matrix" + + # Use SVD to decompose a matrix, default full_matrices is False to save parameters + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) + + L = U @ (torch.sqrt(torch.diag(S)[:, 0:reduced_rank])) + R = torch.sqrt(torch.diag(S)[0:reduced_rank, :]) @ Vh + + return {"L": L, "R": R, "U": U, "S": S, "Vh": Vh, "reduced_rank": reduced_rank} + + +@torch.no_grad() +def 
loftq_init(weight, num_bits: int, reduced_rank: int, num_iter: int, quiet=False): + assert num_bits in [2, 4, 8], "Only support 2, 4, 8 bits quantization" + assert num_iter > 0, "Number of iterations must be greater than 0" + + out_feature, in_feature = weight.size() + device = weight.device + dtype = weight.dtype + if not quiet: + print( + f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} " + f"| Num Iter: {num_iter} | Num Bits: {num_bits}" + ) + if not is_bnb_4bit_available(): + quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64) + + res = weight.clone() + for i in range(num_iter): + torch.cuda.empty_cache() + # Quantization + if num_bits == 4 and is_bnb_4bit_available(): + qweight = bnb.nn.Params4bit(res.to("cpu"), requires_grad=False, compress_statistics=False).to(device) + dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state) + else: + quantized_weight, max_abs, shape = quantizer.quantize_block(res) + dequantized_weight = quantizer.dequantize_block(quantized_weight, max_abs, shape) + + res = weight - dequantized_weight + + # Decompose the residual by SVD + output = _low_rank_decomposition(res, reduced_rank=reduced_rank) + L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"] + res = weight - torch.mm(L, R) + + lora_A, lora_B = R, L + + return dequantized_weight.to(dtype), lora_A, lora_B From 667529aef9344f436cd53a01cb11aa26dcd02aa2 Mon Sep 17 00:00:00 2001 From: "Yixiao Li, Macbook Air" Date: Mon, 20 Nov 2023 09:26:38 -0500 Subject: [PATCH 02/10] faster regex --- examples/loftq_finetuning/train_gsm8k_llama.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/loftq_finetuning/train_gsm8k_llama.py b/examples/loftq_finetuning/train_gsm8k_llama.py index d0ca714d45..05c2b33f11 100644 --- a/examples/loftq_finetuning/train_gsm8k_llama.py +++ b/examples/loftq_finetuning/train_gsm8k_llama.py @@ -814,15 +814,18 @@ def preprocess_function_test(examples): repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) +PATTERN_NUMBER = re.compile(r"-?\d+\.?\d*") + + def extract_answer_number(sentence: str) -> float: sentence = sentence.replace(",", "") - pred = [s for s in re.findall(r"-?\d+\.?\d*", sentence)] + pred = PATTERN_NUMBER.findall(sentence) if not pred: return float("inf") segment = sentence.split("The final answer is ") if len(segment) > 1: pred_answer = segment[1] - pred_answer = [s for s in re.findall(r"-?\d+\.?\d*", pred_answer)] + pred_answer = PATTERN_NUMBER.findall(pred_answer) if len(pred_answer) > 0: pred_answer = pred_answer[0] else: From 86ba774749502de803338c2dab5fa558a239710f Mon Sep 17 00:00:00 2001 From: "Yixiao Li, Macbook Air" Date: Mon, 20 Nov 2023 11:22:44 -0500 Subject: [PATCH 03/10] add scipy to requirements --- requirements.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..dca857de32 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +accelerate +torch +safetensors +bitsandbytes +scipy +peft +transformers +tqdm +packaging +pytest +numpy +pyyaml +datasets +psutil +setuptools \ No newline at end of file From d01ccfdde955143688c22ce6fc75b242c8619923 Mon Sep 17 00:00:00 2001 From: "Yixiao Li, Macbook Air" Date: Mon, 20 Nov 2023 14:12:10 -0500 Subject: [PATCH 04/10] change to nf4 --- .../loftq_finetuning/quantize_save_load.py | 36 +++++++++++++++++-- src/peft/utils/loftq_utils.py | 4 ++- 2 files 
changed, 36 insertions(+), 4 deletions(-) diff --git a/examples/loftq_finetuning/quantize_save_load.py b/examples/loftq_finetuning/quantize_save_load.py index 3966dea607..4d775e82ef 100644 --- a/examples/loftq_finetuning/quantize_save_load.py +++ b/examples/loftq_finetuning/quantize_save_load.py @@ -8,6 +8,7 @@ AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer, + BitsAndBytesConfig, ) from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model @@ -171,11 +172,40 @@ def quantize_and_save(): def load_loftq(base_model_path, lora_adapter_path): if "llama" in base_model_path.lower(): - model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map="auto", load_in_4bit=True) + model = AutoModelForCausalLM.from_pretrained( + base_model_path, + device_map="auto", + low_cpu_mem_usage=True, + load_in_4bit=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) elif "bart" in base_model_path.lower(): - model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path, device_map="auto", load_in_4bit=True) + model = AutoModelForSeq2SeqLM.from_pretrained( + base_model_path, + device_map="auto", + low_cpu_mem_usage=True, + load_in_4bit=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) elif "deberta" in base_model_path.lower(): - model = AutoModelForSequenceClassification.from_pretrained(base_model_path, load_in_4bit=True) + model = AutoModelForSequenceClassification.from_pretrained( + base_model_path, + low_cpu_mem_usage=True, + load_in_4bit=True, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_type="nf4", + ), + ) else: raise NotImplementedError("Other models not supported yet.") diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py index 4002443dc7..728670c940 100644 --- a/src/peft/utils/loftq_utils.py +++ b/src/peft/utils/loftq_utils.py @@ -182,7 +182,9 @@ def loftq_init(weight, num_bits: int, reduced_rank: int, num_iter: int, quiet=Fa torch.cuda.empty_cache() # Quantization if num_bits == 4 and is_bnb_4bit_available(): - qweight = bnb.nn.Params4bit(res.to("cpu"), requires_grad=False, compress_statistics=False).to(device) + qweight = bnb.nn.Params4bit( + res.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4" + ).to(device) dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state) else: quantized_weight, max_abs, shape = quantizer.quantize_block(res) From e04fab2979b10443d00614bfbdbb7cc2a40a6633 Mon Sep 17 00:00:00 2001 From: "Yixiao Li, Macbook Air" Date: Mon, 20 Nov 2023 23:17:37 -0500 Subject: [PATCH 05/10] to float32 when svd change peft model path on HF auto float16/32 for bnb --- examples/loftq_finetuning/train_gsm8k_llama.py | 6 ++++-- src/peft/tuners/lora/layer.py | 1 - src/peft/utils/loftq_utils.py | 15 +++++++++------ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/loftq_finetuning/train_gsm8k_llama.py b/examples/loftq_finetuning/train_gsm8k_llama.py index 05c2b33f11..50236c804f 100644 --- a/examples/loftq_finetuning/train_gsm8k_llama.py +++ b/examples/loftq_finetuning/train_gsm8k_llama.py @@ -467,6 +467,7 @@ def main(): load_in_4bit=True, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=config.torch_dtype, ), ) else: @@ -477,8 +478,9 @@ def main(): # Peft Model # 
########################## if args.adapter_name_or_path is None: - args.adapter_name_or_path = args.model_name_or_path - model = PeftModel.from_pretrained(model, args.adapter_name_or_path, is_trainable=True) + model = PeftModel.from_pretrained(model, args.model_name_or_path, subfolder="loftq_init", is_trainable=True) + else: + model = PeftModel.from_pretrained(model, args.adapter_name_or_path, is_trainable=True) model.print_trainable_parameters() # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 1f9e2cc00a..964b1e6edc 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -176,7 +176,6 @@ def loftq_init(self, adapter_name): "num_bits": self.kwargs.get("loftq_bits", 4), "reduced_rank": self.r[adapter_name], "num_iter": self.kwargs.get("loftq_iter", 1), - "quiet": False, } qweight, lora_A, lora_B = loftq_init(weight, **kwargs) diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py index 728670c940..352b3c4638 100644 --- a/src/peft/utils/loftq_utils.py +++ b/src/peft/utils/loftq_utils.py @@ -1,3 +1,5 @@ +import logging + import torch from scipy.stats import norm @@ -162,21 +164,22 @@ def _low_rank_decomposition(weight, reduced_rank=32): @torch.no_grad() -def loftq_init(weight, num_bits: int, reduced_rank: int, num_iter: int, quiet=False): +def loftq_init(weight, num_bits: int, reduced_rank: int, num_iter: int): assert num_bits in [2, 4, 8], "Only support 2, 4, 8 bits quantization" assert num_iter > 0, "Number of iterations must be greater than 0" out_feature, in_feature = weight.size() device = weight.device dtype = weight.dtype - if not quiet: - print( - f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} " - f"| Num Iter: {num_iter} | Num Bits: {num_bits}" - ) + + logging.info( + f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} " + f"| Num Iter: {num_iter} | Num Bits: {num_bits}" + ) if not is_bnb_4bit_available(): quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64) + weight = weight.to(torch.float32) res = weight.clone() for i in range(num_iter): torch.cuda.empty_cache() From 6aba235fbbe7f89cd0841f63d6e6a14416edbe55 Mon Sep 17 00:00:00 2001 From: yxli2123 <69247082+yxli2123@users.noreply.github.com> Date: Sun, 26 Nov 2023 11:37:50 -0500 Subject: [PATCH 06/10] Update examples/loftq_finetuning/README.md Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> --- examples/loftq_finetuning/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/loftq_finetuning/README.md b/examples/loftq_finetuning/README.md index 02a005a6a5..16a755fd37 100644 --- a/examples/loftq_finetuning/README.md +++ b/examples/loftq_finetuning/README.md @@ -2,7 +2,7 @@ ## Introduction -LoftQ provides better initialization for LoRA adaptors A and B, +LoftQ provides better initialization for LoRA adapters A and B, and the Quantization of pre-trained weights W. 
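Concretely, that initialization is the alternating loop in `loftq_init`: quantize the current residual, then refit a rank-r correction to whatever the quantizer lost. A compressed sketch of the loop, assuming a generic `quantize_dequantize` stand-in for the NF4 round trip:

```python
import torch

def loftq_sketch(W, rank, num_iter, quantize_dequantize):
    """Alternate between quantizing W - B @ A and refitting A, B to the residual."""
    res = W.clone()
    for _ in range(num_iter):
        Q = quantize_dequantize(res)                   # NF4 quantize + dequantize of the residual
        U, S, Vh = torch.linalg.svd(W - Q, full_matrices=False)
        B = U @ torch.sqrt(torch.diag(S)[:, :rank])    # becomes lora_B
        A = torch.sqrt(torch.diag(S)[:rank, :]) @ Vh   # becomes lora_A
        res = W - B @ A                                # next iteration quantizes this residual
    return Q, A, B                                     # dequantized backbone, lora_A, lora_B
```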
## Quantization From d4523c893454d969137f945c75df92915cecaebb Mon Sep 17 00:00:00 2001 From: "Yixiao Li, Macbook Air" Date: Sun, 26 Nov 2023 11:53:29 -0500 Subject: [PATCH 07/10] import scipy at loftq_init --- examples/loftq_finetuning/quantize_save_load.py | 4 ++-- src/peft/tuners/lora/config.py | 17 ----------------- src/peft/tuners/lora/layer.py | 2 +- src/peft/utils/loftq_utils.py | 6 +++++- 4 files changed, 8 insertions(+), 21 deletions(-) diff --git a/examples/loftq_finetuning/quantize_save_load.py b/examples/loftq_finetuning/quantize_save_load.py index 4d775e82ef..e73a331395 100644 --- a/examples/loftq_finetuning/quantize_save_load.py +++ b/examples/loftq_finetuning/quantize_save_load.py @@ -129,13 +129,13 @@ def quantize_and_save(): raise NotImplementedError("Other models not supported yet.") # Config of LoftQ - loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter, loftq_fake=True) + loftq_config = LoftQConfig(loftq_bits=args.bits, loftq_iter=args.iter) lora_config = LoraConfig( task_type=task_type, inference_mode=True, r=args.rank, - lora_alpha=args.rank, + lora_alpha=16 if task_type is TaskType.CAUSAL_LM else args.rank, lora_dropout=0.1, target_modules=target_modules, init_lora_weights="loftq", diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index ab68ca0f96..ade4b7892f 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -35,25 +35,8 @@ class LoftQConfig: bits. """ - bits_pattern: Optional[dict] = field( - default_factory=dict, - metadata={ - "help": ( - "The mapping from layer names or regexp expression to ranks which are different from " - "the default bits specified by `bits`. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 2`}" - ) - }, - ) loftq_bits: int = field(default=4, metadata={"help": "Quantization bits for LoftQ"}) loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"}) - loftq_fake: bool = field( - default=True, - metadata={ - "help": "True: use fp16/fp32; used for first time to save weights." - "False: use bitsandbytes 4bit linear models. weights can't be saved." - "Recommend to set to True, save the weights and load the saved weights in 4 bits." - }, - ) @dataclass diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 964b1e6edc..57591ff209 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -23,7 +23,6 @@ from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer -from peft.utils.loftq_utils import loftq_init from peft.utils.other import transpose @@ -171,6 +170,7 @@ def reset_lora_parameters(self, adapter_name): nn.init.normal_(self.lora_embedding_B[adapter_name]) def loftq_init(self, adapter_name): + from peft.utils.loftq_utils import loftq_init weight = self.get_base_layer().weight kwargs = { "num_bits": self.kwargs.get("loftq_bits", 4), diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py index 352b3c4638..b33ebada86 100644 --- a/src/peft/utils/loftq_utils.py +++ b/src/peft/utils/loftq_utils.py @@ -1,7 +1,11 @@ import logging import torch -from scipy.stats import norm +try: + from scipy.stats import norm +except ImportError: + raise ImportError("The required package 'scipy' is not installed. 
Please install it to continue.") + from peft.import_utils import is_bnb_4bit_available, is_bnb_available From fe9e84a77fdf85a59745d5ed9897920f9ae24721 Mon Sep 17 00:00:00 2001 From: yxli2123 <69247082+yxli2123@users.noreply.github.com> Date: Tue, 28 Nov 2023 10:21:35 -0500 Subject: [PATCH 08/10] Update examples/loftq_finetuning/README.md Co-authored-by: Benjamin Bossan --- examples/loftq_finetuning/README.md | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/loftq_finetuning/README.md b/examples/loftq_finetuning/README.md index 16a755fd37..86131ebada 100644 --- a/examples/loftq_finetuning/README.md +++ b/examples/loftq_finetuning/README.md @@ -35,13 +35,15 @@ from transformers import AutoModelForCausalLM from peft import PeftModel -base_model = AutoModelForCausalLM.from_pretrained(os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-backbone"), - load_in_4bit=True, - ) -peft_model = PeftModel.from_pretrained(base_model, - os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-adapters", - is_trainable=True), - ) +base_model = AutoModelForCausalLM.from_pretrained( + os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-backbone"), + load_in_4bit=True, +) +peft_model = PeftModel.from_pretrained( + base_model, + os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-adapters", + is_trainable=True, +) ``` We also provide an example to fine-tune LoftQ on GSM8K. From cf68839639fe9361c19acf8b7088bbaacc80ad9e Mon Sep 17 00:00:00 2001 From: "Yixiao Li, Macbook Air" Date: Tue, 28 Nov 2023 10:43:22 -0500 Subject: [PATCH 09/10] add license, import scipy at post_init, change assert to ValueError msg, edit example docs --- examples/loftq_finetuning/README.md | 11 ++-- .../loftq_finetuning/quantize_save_load.py | 38 ++++++++---- .../loftq_finetuning/train_gsm8k_llama.py | 12 +--- src/peft/tuners/lora/config.py | 10 +++- src/peft/tuners/lora/layer.py | 1 + src/peft/utils/loftq_utils.py | 60 ++++++++++++------- 6 files changed, 80 insertions(+), 52 deletions(-) diff --git a/examples/loftq_finetuning/README.md b/examples/loftq_finetuning/README.md index 86131ebada..726f544e85 100644 --- a/examples/loftq_finetuning/README.md +++ b/examples/loftq_finetuning/README.md @@ -21,8 +21,8 @@ python quantize_save_load.py \ - `HF_TOKEN` is the token used to access to [LLAMA models](https://huggingface.co/meta-llama). - `quantize_and_save()` function will quantize the backbone and initialize LoRA adapters. -It creates 2 folders under `$save_dir`. The quantized backbone is at `Llama-2-7b-hf-4bit-16rank-backbone`, -and the LoRA adapters are at `Llama-2-7b-hf-4bit-16rank-adapters`. +It creates 2 folders under `$save_dir`. The quantized backbone is at `Llama-2-7b-hf-4bit-16rank`, +and the LoRA adapters are at the sub-folder `Llama-2-7b-hf-4bit-16rank/loftq_init`. 
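Under the hood, `quantize_and_save()` simply builds a LoftQ-aware LoRA config and lets `get_peft_model` run the initialization; a trimmed sketch of that flow (the model name and hyper-parameter values are placeholders, see the script for the real argument handling):

```python
from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")

loftq_config = LoftQConfig(loftq_bits=4, loftq_iter=5)  # quantization bits / alternating iterations
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=True,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    init_lora_weights="loftq",  # triggers loftq_init on every targeted layer
    loftq_config=loftq_config,
)
# Backbone weights are replaced by their dequantized NF4 approximation,
# and lora_A / lora_B hold the LoftQ low-rank correction.
lora_model = get_peft_model(model, lora_config)
```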
## Fine-tuning @@ -36,12 +36,12 @@ from peft import PeftModel base_model = AutoModelForCausalLM.from_pretrained( - os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-backbone"), + os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank"), load_in_4bit=True, ) peft_model = PeftModel.from_pretrained( base_model, - os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank-adapters", + os.path.join(args.save_dir, "Llama-2-7b-hf-4bit-16rank", "loftq_init"), is_trainable=True, ) ``` @@ -51,8 +51,7 @@ We load the quantized backbone and LoRA adapters from the [LoftQ Huggingface hub ```sh python train_gsm8k_llama.py \ - --model_name_or_path LoftQ/Llama-2-7b-hf-6bit-64rank-backbone \ - --adapter_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank-adapters \ + --model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \ --output_dir exp_results/gsm8k/llama-2-7b/bit4-rank64/lr3e-4 \ --learning_rate 3e-4 \ --seed 202 \ diff --git a/examples/loftq_finetuning/quantize_save_load.py b/examples/loftq_finetuning/quantize_save_load.py index e73a331395..3c47fa45cd 100644 --- a/examples/loftq_finetuning/quantize_save_load.py +++ b/examples/loftq_finetuning/quantize_save_load.py @@ -1,3 +1,18 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
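The fine-tuning command above points `--model_name_or_path` at a single Hub repository; `train_gsm8k_llama.py` then loads the 4-bit backbone and picks the adapters up from its `loftq_init` sub-folder. A minimal sketch of that loading path, mirroring the NF4 settings the script uses (the compute dtype here is an assumption; the script takes it from the model config):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

MODEL_ID = "LoftQ/Llama-2-7b-hf-4bit-64rank"

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,  # assumed; the script uses config.torch_dtype
    ),
)
peft_model = PeftModel.from_pretrained(
    base_model,
    MODEL_ID,
    subfolder="loftq_init",  # LoRA adapters produced by the LoftQ initialization
    is_trainable=True,
)
```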
+ import argparse import os @@ -109,18 +124,20 @@ def quantize_and_save(): args = arg_parse() # Download weights and configure LoRA - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token) - if "llama" in args.model_name_or_path.lower(): - model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, token=args.token, trust_remote_code=True) + if any(name in args.model_name_or_path.lower() for name in ["llama", "mistral", "falcon"]): + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, token=args.token, trust_remote_code=True, device_map="auto" + ) task_type = TaskType.CAUSAL_LM target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"] - elif "bart" in args.model_name_or_path.lower(): + elif any(name in args.model_name_or_path.lower() for name in ["bart", "t5"]): model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, token=args.token, device_map="auto") task_type = TaskType.SEQ_2_SEQ_LM target_modules = ["q_proj", "k_proj", "v_proj", "fc1", "fc2", "out_proj"] - elif "deberta" in args.model_name_or_path.lower(): + elif any(name in args.model_name_or_path.lower() for name in ["deberta", "roberta", "bert"]): model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, token=args.token) model = model.cuda() task_type = TaskType.SEQ_CLS @@ -148,8 +165,8 @@ def quantize_and_save(): # Save LoftQ model model_name = args.model_name_or_path.split("/")[-1] + f"-{args.bits}bit" + f"-{args.rank}rank" - base_model_dir = os.path.join(args.save_dir, model_name + "-backbone") - lora_model_dir = os.path.join(args.save_dir, model_name + "-adapters") + base_model_dir = os.path.join(args.save_dir, model_name) + lora_model_dir = os.path.join(args.save_dir, model_name, "loft_init") # save lora adapters first lora_model.base_model.peft_config[ @@ -171,19 +188,18 @@ def quantize_and_save(): def load_loftq(base_model_path, lora_adapter_path): - if "llama" in base_model_path.lower(): + if any(name in base_model_path.lower() for name in ["llama", "mistral", "falcon"]): model = AutoModelForCausalLM.from_pretrained( base_model_path, device_map="auto", low_cpu_mem_usage=True, - load_in_4bit=True, quantization_config=BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=False, bnb_4bit_quant_type="nf4", ), ) - elif "bart" in base_model_path.lower(): + elif any(name in base_model_path.lower() for name in ["bart", "t5"]): model = AutoModelForSeq2SeqLM.from_pretrained( base_model_path, device_map="auto", @@ -195,7 +211,7 @@ def load_loftq(base_model_path, lora_adapter_path): bnb_4bit_quant_type="nf4", ), ) - elif "deberta" in base_model_path.lower(): + elif any(name in base_model_path.lower() for name in ["deberta", "roberta", "bert"]): model = AutoModelForSequenceClassification.from_pretrained( base_model_path, low_cpu_mem_usage=True, diff --git a/examples/loftq_finetuning/train_gsm8k_llama.py b/examples/loftq_finetuning/train_gsm8k_llama.py index 50236c804f..e8c3580d2e 100644 --- a/examples/loftq_finetuning/train_gsm8k_llama.py +++ b/examples/loftq_finetuning/train_gsm8k_llama.py @@ -1,6 +1,5 @@ -#!/usr/bin/env python # coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# Copyright 2023-present the HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) -on a text file or a dataset without using HuggingFace Trainer. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=text-generation -""" -# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. import argparse import copy @@ -462,7 +453,6 @@ def main(): from_tf=bool(".ckpt" in args.model_name_or_path), config=config, low_cpu_mem_usage=True, - load_in_4bit=True, quantization_config=BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=False, diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index ade4b7892f..1270d316ef 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from dataclasses import dataclass, field from typing import List, Optional, Union @@ -162,8 +161,13 @@ def __post_init__(self): raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.") # handle init_lora_weights and loftq_config - if self.init_lora_weights == "loftq" and self.loftq_config is None: - raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") + if self.init_lora_weights == "loftq": + import importlib + + if not importlib.util.find_spec("scipy"): + raise ImportError("The required package 'scipy' is not installed. Please install it to continue.") + if self.loftq_config is None: + raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.") # convert loftq_config to dict if self.loftq_config is not None and not isinstance(self.loftq_config, dict): diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 57591ff209..f0d00c3ea2 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -171,6 +171,7 @@ def reset_lora_parameters(self, adapter_name): def loftq_init(self, adapter_name): from peft.utils.loftq_utils import loftq_init + weight = self.get_base_layer().weight kwargs = { "num_bits": self.kwargs.get("loftq_bits", 4), diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py index b33ebada86..81ff1e2c34 100644 --- a/src/peft/utils/loftq_utils.py +++ b/src/peft/utils/loftq_utils.py @@ -1,11 +1,25 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
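With the `__post_init__` checks added above, a mis-configured LoftQ setup fails fast instead of silently skipping initialization. Roughly, the expected behaviour looks like this sketch (exception messages abbreviated):

```python
from peft import LoftQConfig, LoraConfig

# Missing loftq_config -> ValueError (or ImportError first if scipy is absent).
try:
    LoraConfig(init_lora_weights="loftq")
except (ValueError, ImportError) as exc:
    print(exc)

# Correct usage: the LoftQConfig dataclass is converted to a plain dict internally.
config = LoraConfig(init_lora_weights="loftq", loftq_config=LoftQConfig(loftq_bits=4, loftq_iter=1))
print(type(config.loftq_config))  # dict
```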
+ +# Reference code: https://github.com/yxli2123/LoftQ/blob/main/utils.py +# Reference paper: https://arxiv.org/abs/2310.08659 + import logging +from typing import Union import torch -try: - from scipy.stats import norm -except ImportError: - raise ImportError("The required package 'scipy' is not installed. Please install it to continue.") - from peft.import_utils import is_bnb_4bit_available, is_bnb_available @@ -44,10 +58,13 @@ def create_uniform_map(symmetric=False, num_bits=4): @staticmethod def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2): - variations = 2**num_bits + try: + from scipy.stats import norm + except ImportError: + raise ImportError("The required package 'scipy' is not installed. Please install it to continue.") + variations = 2**num_bits if symmetric: - # print("symmetric NormalFloat") v = norm.ppf(torch.linspace(1 - offset, offset, variations + 1)).tolist() values = [] for index in range(len(v) - 1): @@ -55,23 +72,15 @@ def create_normal_map(offset=0.9677083, symmetric=False, num_bits=2): v = values else: # one more positive value, this is an asymmetric type - # print("asymmetric NormalFloat") v1 = norm.ppf(torch.linspace(offset, 0.5, variations // 2 + 1)[:-1]).tolist() - # print(torch.linspace(offset, 0.5, 9)[:-1]) - # print(v1) v2 = [0] - # v2 = [0]*(256-15) ## we have 15 non-zero values in this data type v3 = (-norm.ppf(torch.linspace(offset, 0.5, variations // 2)[:-1])).tolist() - # print(torch.linspace(offset, 0.5, 8)[:-1]) - # print(v3) v = v1 + v2 + v3 values = torch.Tensor(v) values = values.sort().values values /= values.max() - # print(values) return values - # assert values. def quantize_tensor(self, weight): max_abs = torch.abs(weight).max() @@ -87,7 +96,6 @@ def quantize_tensor(self, weight): # Find the index of the minimum absolute difference for each element qweight = torch.argmin(abs_diff, dim=-1) - # print(min_index) return qweight, max_abs def dequantize_tensor(self, qweight, max_abs): @@ -101,7 +109,14 @@ def dequantize_tensor(self, qweight, max_abs): return weight def quantize_block(self, weight): - assert len(weight.shape) == 2 and weight.shape[0] * weight.shape[1] % self.block_size == 0 + if len(weight.shape) != 2: + raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.") + if weight.shape[0] * weight.shape[1] % self.block_size != 0: + raise ValueError( + f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) " + f"is not dividable by block size {self.block_size}." 
+            )
+
         M, N = weight.shape
         device = weight.device
@@ -156,7 +171,8 @@ def _low_rank_decomposition(weight, reduced_rank=32):
     """
     :param weight: The matrix to decompose, of shape (H, W) :param reduced_rank: the final rank :return:
     """
     matrix_dimension = len(weight.size())
-    assert matrix_dimension == 2, "Only Support 2D matrix"
+    if matrix_dimension != 2:
+        raise ValueError(f"Only support 2D matrix, but your input has {matrix_dimension} dimensions.")
 
     # Use SVD to decompose a matrix, default full_matrices is False to save parameters
     U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
@@ -168,9 +184,11 @@ def _low_rank_decomposition(weight, reduced_rank=32):
 
 
 @torch.no_grad()
-def loftq_init(weight, num_bits: int, reduced_rank: int, num_iter: int):
-    assert num_bits in [2, 4, 8], "Only support 2, 4, 8 bits quantization"
-    assert num_iter > 0, "Number of iterations must be greater than 0"
+def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1):
+    if num_bits not in [2, 4, 8]:
+        raise ValueError("Only support 2, 4, 8 bits quantization")
+    if num_iter <= 0:
+        raise ValueError("Number of iterations must be greater than 0")
 
     out_feature, in_feature = weight.size()
     device = weight.device

From aaff3154fc3df1e46a7b4a9c53a21766791d03ca Mon Sep 17 00:00:00 2001
From: "Yixiao Li, Macbook Air"
Date: Wed, 29 Nov 2023 10:48:06 -0500
Subject: [PATCH 10/10] add loftq to Literal

---
 src/peft/tuners/lora/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 146460a3f3..0dcca5c1e6 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -97,7 +97,7 @@ class LoraConfig(PeftConfig):
             "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
         },
     )
-    init_lora_weights: bool | Literal["gaussian"] = field(
+    init_lora_weights: bool | Literal["gaussian", "loftq"] = field(
         default=True,
         metadata={
             "help": (
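When `bitsandbytes` 4-bit support is available, each `loftq_init` iteration round-trips the residual through the library's NF4 kernels instead of the pure-PyTorch `NFQuantizer`. The core of that branch, lifted from the loop shown earlier (requires a CUDA device; the random tensor is only a stand-in for the residual):

```python
import bitsandbytes as bnb
import torch

res = torch.randn(4096, 4096, dtype=torch.float32)  # residual being quantized in this iteration

# Quantize to NF4 on the GPU, then dequantize back to get the effective backbone weight.
qweight = bnb.nn.Params4bit(
    res.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4"
).to("cuda")
dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)

# loftq_init then fits the LoRA factors (via SVD) to `weight - dequantized_weight`.
```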