diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile
index 925ade2e5a..5e86834d53 100644
--- a/docker/peft-gpu/Dockerfile
+++ b/docker/peft-gpu/Dockerfile
@@ -40,6 +40,12 @@ SHELL ["/bin/bash", "-c"]
 RUN source activate peft && \
     python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq

+# Add autoawq for quantization testing
+RUN source activate peft && \
+    python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.1/autoawq-0.2.1-cp38-cp38-linux_x86_64.whl
+RUN source activate peft && \
+    python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl
+
 # Install apt libs
 RUN apt-get update && \
     apt-get install -y curl git wget && \
diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py
index 677cf23ed6..b3ff72bc05 100644
--- a/src/peft/import_utils.py
+++ b/src/peft/import_utils.py
@@ -63,3 +63,7 @@ def is_torch_tpu_available(check_device=True):
                 return False
             return True
     return False
+
+
+def is_auto_awq_available():
+    return importlib.util.find_spec("awq") is not None
diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py
new file mode 100644
index 0000000000..b3f5bf3978
--- /dev/null
+++ b/src/peft/tuners/lora/awq.py
@@ -0,0 +1,108 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib.metadata as importlib_metadata
+from typing import Any, Optional
+
+import packaging.version
+import torch
+
+from peft.import_utils import is_auto_awq_available
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+
+if is_auto_awq_available():
+    from awq.modules.linear import WQLinear_GEMM
+
+
+class AwqLoraLinear(torch.nn.Module, LoraLayer):
+    def __init__(
+        self,
+        base_layer,
+        adapter_name,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        init_lora_weights: bool = True,
+        use_rslora: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        LoraLayer.__init__(self, base_layer)
+
+        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
+        # for backwards compatibility
+        self.quant_linear_module = base_layer
+
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+
+    def forward(self, x: torch.Tensor):
+        result = self.quant_linear_module(x)
+
+        if self.disable_adapters:
+            return result
+
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.lora_A.keys():
+                continue
+            lora_A = self.lora_A[active_adapter]
+            lora_B = self.lora_B[active_adapter]
+            dropout = self.lora_dropout[active_adapter]
+            scaling = self.scaling[active_adapter]
+
+            requires_conversion = not torch.is_autocast_enabled()
+            if requires_conversion:
+                expected_dtype = result.dtype
+                x = x.to(lora_A.weight.dtype)
+
+            output = lora_B(lora_A(dropout(x)))
+            if requires_conversion:
+                output = output.to(expected_dtype)
+            output = output * scaling
+            result = result + output
+        return result
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
+
+
+def dispatch_awq(
+    target: torch.nn.Module,
+    adapter_name: str,
+    **kwargs: Any,
+) -> Optional[torch.nn.Module]:
+    new_module = None
+
+    if isinstance(target, BaseTunerLayer):
+        target_base_layer = target.get_base_layer()
+    else:
+        target_base_layer = target
+
+    if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM):
+        # Raise the error only at the dispatch level
+        AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
+        version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
+
+        if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
+            raise ImportError(
+                f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
+                f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
+            )
+
+        new_module = AwqLoraLinear(target, adapter_name, **kwargs)
+        target.qweight = target_base_layer.qweight
+
+    return new_module
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 0dfe97966c..5c920be1aa 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -66,6 +66,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
             # Megatron ColumnParallelLinear,RowParallelLinear
             in_features, out_features = base_layer.input_size, base_layer.output_size
+        elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
+            # Awq layers
+            in_features, out_features = base_layer.in_features, base_layer.out_features
         else:
             raise ValueError(f"Unsupported layer type {type(base_layer)}")

diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 3b12332149..ca64a8fe21 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -38,6 +38,7 @@
 )
 from peft.utils.merge_utils import dare_linear, dare_ties, task_arithmetic, ties

+from .awq import dispatch_awq
 from .config import LoraConfig
 from .gptq import dispatch_gptq
 from .layer import Conv2d, LoraLayer, dispatch_default
@@ -156,9 +157,11 @@ def _create_and_replace(
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
         }

-        quantization_config = get_quantization_config(self.model, method="gptq")
-        if quantization_config is not None:
-            kwargs["gptq_quantization_config"] = quantization_config
+        quant_methods = ["gptq", "awq"]
+        for quant_method in quant_methods:
+            quantization_config = get_quantization_config(self.model, method=quant_method)
+            if quantization_config is not None:
+                kwargs[f"{quant_method}_quantization_config"] = quantization_config

         # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
         from peft.tuners.adalora import AdaLoraLayer
@@ -244,7 +247,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

             dispatchers.append(dispatch_bnb_4bit)

-        dispatchers.extend([dispatch_gptq, dispatch_megatron, dispatch_default])
+        dispatchers.extend([dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])

         new_module = None
         for dispatcher in dispatchers:
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index a73a1c6461..943f3f54de 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -53,6 +53,7 @@
 from peft.utils import SAFETENSORS_WEIGHTS_NAME

 from .testing_utils import (
+    require_auto_awq,
     require_auto_gptq,
     require_bitsandbytes,
     require_optimum,
@@ -1378,3 +1379,146 @@ def test_model_loaded_in_float16_working(self):
             data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
         )
         trainer.train()
+
+
+@require_torch_gpu
+@require_auto_awq
+class PeftAwqGPUTests(unittest.TestCase):
+    r"""
+    Awq + peft tests
+    """
+
+    def setUp(self):
+        self.causal_lm_model_id = "peft-internal-testing/opt-125m-awq"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
+
+    def tearDown(self):
+        r"""
+        Efficient mechanism to free GPU memory after each test. Based on
+        https://github.com/huggingface/transformers/issues/21094
+        """
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def _check_inference_finite(self, model, batch):
+        # try inference without Trainer class
+        training = model.training
+        model.eval()
+        output = model(**batch.to(model.device))
+        assert torch.isfinite(output.logits).all()
+        model.train(training)
+
+    @pytest.mark.single_gpu_tests
+    def test_causal_lm_training_awq(self):
+        r"""
+        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
+        correctly.
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model = AutoModelForCausalLM.from_pretrained(
+                self.causal_lm_model_id,
+                device_map="auto",
+            )
+
+            model = prepare_model_for_kbit_training(model)
+            config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["q_proj", "v_proj"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+            model = get_peft_model(model, config)
+
+            data = load_dataset("ybelkada/english_quotes_copy")
+            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
+
+            # TODO: deal correctly with this case in transformers
+            model._is_quantized_training_enabled = True
+
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    per_device_train_batch_size=4,
+                    gradient_accumulation_steps=4,
+                    warmup_steps=2,
+                    max_steps=3,
+                    learning_rate=2e-4,
+                    logging_steps=1,
+                    output_dir=tmp_dir,
+                    fp16=True,
+                ),
+                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
+            )
+            model.config.use_cache = False
+            trainer.train()
+
+            model.cpu().save_pretrained(tmp_dir)
+
+            assert "adapter_config.json" in os.listdir(tmp_dir)
+            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
+
+            # assert loss is not None
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+    @pytest.mark.multi_gpu_tests
+    @require_torch_multi_gpu
+    def test_causal_lm_training_multi_gpu(self):
+        r"""
+        Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
+        correctly.
+ """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + ) + + assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count())) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None diff --git a/tests/testing_utils.py b/tests/testing_utils.py index bd7047e31b..f6063505e9 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -18,7 +18,7 @@ import pytest import torch -from peft.import_utils import is_auto_gptq_available, is_optimum_available +from peft.import_utils import is_auto_awq_available, is_auto_gptq_available, is_optimum_available def require_torch_gpu(test_case): @@ -61,6 +61,13 @@ def require_auto_gptq(test_case): return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) +def require_auto_awq(test_case): + """ + Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed. + """ + return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case) + + def require_optimum(test_case): """ Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed.