From 35b1dec17ceae11b5ea73c911003e9c3d714e49f Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Fri, 26 Jan 2024 18:56:43 +0100
Subject: [PATCH 01/22] add awq support in PEFT

---
 src/peft/import_utils.py      | 19 +++++++
 src/peft/tuners/lora/awq.py   | 97 +++++++++++++++++++++++++++++++++++
 src/peft/tuners/lora/layer.py |  3 ++
 src/peft/tuners/lora/model.py | 11 ++--
 tests/test_gpu_examples.py    | 86 +++++++++++++++++++++++++++++++
 tests/testing_utils.py        |  9 +++-
 6 files changed, 220 insertions(+), 5 deletions(-)
 create mode 100644 src/peft/tuners/lora/awq.py

diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py
index f82d2238f1..eb50ebd41b 100644
--- a/src/peft/import_utils.py
+++ b/src/peft/import_utils.py
@@ -64,3 +64,22 @@ def is_torch_tpu_available(check_device=True):
                 return False
         return True
     return False
+
+
+def is_auto_awq_available():
+    if importlib.util.find_spec("awq") is not None:
+        import awq
+
+        awq_version = awq.__version__
+
+        # TODO: change it to 0.2.0
+        AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.8")
+
+        version_autoawq = packaging.version.parse(awq_version)
+        if AUTOAWQ_MINIMUM_VERSION >= version_autoawq:
+            return True
+        else:
+            raise ImportError(
+                f"Found an incompatible version of auto-gptq. Found version {version_autoawq}, "
+                f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported"
+            )
diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py
new file mode 100644
index 0000000000..d32b4d9a5a
--- /dev/null
+++ b/src/peft/tuners/lora/awq.py
@@ -0,0 +1,97 @@
+# coding=utf-8
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from typing import Any, Optional + +import torch + +from peft.import_utils import is_auto_awq_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer + + +if is_auto_awq_available(): + from awq.modules.linear import WQLinear_GEMM as AWQ_WQLinear_GEMM + + +class WQLinear_GEMM(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def forward(self, x: torch.Tensor): + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +def dispatch_awq( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, AWQ_WQLinear_GEMM): + new_module = WQLinear_GEMM(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight + + return new_module diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 92df6c17db..0e5310dad3 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -67,6 +67,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): # Megatron ColumnParallelLinear,RowParallelLinear in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": + # Awq layers + in_features, out_features = base_layer.in_features, base_layer.out_features else: raise ValueError(f"Unsupported layer type {type(base_layer)}") diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 3ed409b10e..58fa51c51b 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -38,6 +38,7 @@ get_quantization_config, ) +from .awq import dispatch_awq from .config import LoraConfig from .gptq import dispatch_gptq from .layer import Conv2d, LoraLayer, dispatch_default @@ -156,9 +157,11 @@ def _create_and_replace( "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), } - quantization_config = get_quantization_config(self.model, method="gptq") - if quantization_config is not None: - kwargs["gptq_quantization_config"] = 
quantization_config + quant_methods = ["gptq", "awq"] + for quant_method in quant_methods: + quantization_config = get_quantization_config(self.model, method=quant_method) + if quantization_config is not None: + kwargs[f"{quant_method}_quantization_config"] = quantization_config # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it from peft.tuners.adalora import AdaLoraLayer @@ -244,7 +247,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): dispatchers.append(dispatch_bnb_4bit) - dispatchers.extend([dispatch_gptq, dispatch_megatron, dispatch_default]) + dispatchers.extend([dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]) new_module = None for dispatcher in dispatchers: diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index a23d4cbb56..a62c2ead71 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -54,6 +54,7 @@ from peft.utils import SAFETENSORS_WEIGHTS_NAME from .testing_utils import ( + require_auto_awq, require_auto_gptq, require_bitsandbytes, require_optimum, @@ -1378,3 +1379,88 @@ def test_model_loaded_in_float16_working(self): data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), ) trainer.train() + + +@require_torch_gpu +@require_auto_awq +class PeftAwqGPUTests(unittest.TestCase): + r""" + Awq + peft tests + """ + + def setUp(self): + + self.causal_lm_model_id = "ybelkada/opt-125m-awq" + self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + + def tearDown(self): + r""" + Efficient mechanism to free GPU memory after each test. Based on + https://github.com/huggingface/transformers/issues/21094 + """ + gc.collect() + torch.cuda.empty_cache() + + def _check_inference_finite(self, model, batch): + # try inference without Trainer class + training = model.training + model.eval() + output = model(**batch.to(model.device)) + self.assertTrue(torch.isfinite(output.logits).all()) + model.train(training) + + @pytest.mark.single_gpu_tests + def test_causal_lm_training_awq(self): + r""" + Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set + correctly. 
+ """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + ) + + model = prepare_model_for_kbit_training(model) + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + # TODO: deal correctly with this case in transformers + model._is_quantized_training_enabled = True + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) + self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)) + + # assert loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index cb5b55e877..fa0f277614 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -19,7 +19,7 @@ import pytest import torch -from peft.import_utils import is_auto_gptq_available, is_optimum_available +from peft.import_utils import is_auto_awq_available, is_auto_gptq_available, is_optimum_available def require_torch_gpu(test_case): @@ -62,6 +62,13 @@ def require_auto_gptq(test_case): return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case) +def require_auto_awq(test_case): + """ + Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed. + """ + return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case) + + def require_optimum(test_case): """ Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed. 
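
Taken as a whole, patch 01 wires AWQ into the LoRA stack in three places: is_auto_awq_available() gates the optional awq import, the new module in src/peft/tuners/lora/awq.py runs the quantized forward pass and adds the LoRA delta on top of it, and dispatch_awq is registered alongside the existing GPTQ, Megatron and default dispatchers. A rough usage sketch of what this enables is shown below; it is not part of the patch, the checkpoint id is simply the one used in the new tests, and the rest is the standard PEFT API.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model

    model_id = "ybelkada/opt-125m-awq"  # AWQ-quantized checkpoint used by the tests in this patch
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

    # dispatch_awq matches the quantized WQLinear_GEMM projections and wraps them, so the scaled
    # LoRA delta lora_B(lora_A(dropout(x))) is added on top of the quantized matmul result.
    config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

    # The forward pass goes through the wrapper's forward() added in lora/awq.py above.
    inputs = tokenizer("Quantized models can still be LoRA-tuned.", return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

Note that the wrapper converts activations to the LoRA weights' dtype (and back) when autocast is disabled, so the float16 output of the quantized layer and the LoRA delta end up in a consistent dtype.
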
From c08b98efa811b612bd1bfce535cc49e83561b136 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 26 Jan 2024 19:00:09 +0100 Subject: [PATCH 02/22] fix --- src/peft/import_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index eb50ebd41b..e23910f333 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -68,14 +68,10 @@ def is_torch_tpu_available(check_device=True): def is_auto_awq_available(): if importlib.util.find_spec("awq") is not None: - import awq - - awq_version = awq.__version__ - # TODO: change it to 0.2.0 AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.8") + version_autoawq = packaging.version.parse(packaging.version.parse(importlib_metadata.version("autoawq"))) - version_autoawq = packaging.version.parse(awq_version) if AUTOAWQ_MINIMUM_VERSION >= version_autoawq: return True else: From 74900573807e1bd4609f2e2348b111cd21526652 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 26 Jan 2024 19:09:11 +0100 Subject: [PATCH 03/22] fux --- src/peft/import_utils.py | 19 ++++++++++--------- src/peft/tuners/lora/awq.py | 1 + 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index e23910f333..664e0a98a3 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -69,13 +69,14 @@ def is_torch_tpu_available(check_device=True): def is_auto_awq_available(): if importlib.util.find_spec("awq") is not None: # TODO: change it to 0.2.0 - AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.8") - version_autoawq = packaging.version.parse(packaging.version.parse(importlib_metadata.version("autoawq"))) + # AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.7") + # version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) + # if AUTOAWQ_MINIMUM_VERSION >= version_autoawq: + # return True + # else: + # raise ImportError( + # f"Found an incompatible version of auto-gptq. Found version {version_autoawq}, " + # f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported" + # ) - if AUTOAWQ_MINIMUM_VERSION >= version_autoawq: - return True - else: - raise ImportError( - f"Found an incompatible version of auto-gptq. 
Found version {version_autoawq}, " - f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported" - ) + return True diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index d32b4d9a5a..d5c1ed46cf 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -44,6 +44,7 @@ def __init__( # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter # for backwards compatibility self.quant_linear_module = base_layer + self._active_adapter = adapter_name self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) From 9283800423db33ae1ddd5fe6df67765b866f146e Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 26 Jan 2024 22:16:47 +0100 Subject: [PATCH 04/22] Update src/peft/tuners/lora/awq.py --- src/peft/tuners/lora/awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index d5c1ed46cf..f205523a35 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -71,7 +71,7 @@ def forward(self, x: torch.Tensor): if requires_conversion: output = output.to(expected_dtype) output = output * scaling - result += output + result = result + output return result def __repr__(self) -> str: From b925c17f52dc42ec5fab0a4b4b6c783ae2b78ee1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 13 Feb 2024 02:12:01 +0100 Subject: [PATCH 05/22] style & fix tests --- src/peft/tuners/lora/awq.py | 3 +-- tests/test_gpu_examples.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index f205523a35..684a241502 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024-present the HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -91,7 +90,7 @@ def dispatch_awq( else: target_base_layer = target - if isinstance(target_base_layer, AWQ_WQLinear_GEMM): + if is_auto_awq_available() and isinstance(target_base_layer, AWQ_WQLinear_GEMM): new_module = WQLinear_GEMM(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index ceae3ec6f3..2b922ef6eb 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1389,7 +1389,6 @@ class PeftAwqGPUTests(unittest.TestCase): """ def setUp(self): - self.causal_lm_model_id = "ybelkada/opt-125m-awq" self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) From b7ac85f183a7c4927c22014408b6fa4f0d402ac5 Mon Sep 17 00:00:00 2001 From: s4rduk4r Date: Tue, 13 Feb 2024 02:17:15 +0100 Subject: [PATCH 06/22] forward contrib credits from PR14084 From 02d6eca8f0f17db14c2461c5f8b99b435a49a2b3 Mon Sep 17 00:00:00 2001 From: s4rduk4r Date: Tue, 13 Feb 2024 02:17:27 +0100 Subject: [PATCH 07/22] forward contrib credits from autoawq PR From 616aefe598d5a6a83ee408fbb8603b9dde7fd04b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Feb 2024 01:34:48 +0100 Subject: [PATCH 08/22] change name --- src/peft/tuners/lora/awq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index 684a241502..d40473346e 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -25,7 +25,7 @@ from awq.modules.linear import WQLinear_GEMM as AWQ_WQLinear_GEMM -class WQLinear_GEMM(torch.nn.Module, LoraLayer): +class AwqLoraLinear(torch.nn.Module, LoraLayer): def __init__( self, base_layer, @@ -91,7 +91,7 @@ def dispatch_awq( target_base_layer = target if is_auto_awq_available() and isinstance(target_base_layer, AWQ_WQLinear_GEMM): - new_module = WQLinear_GEMM(target, adapter_name, **kwargs) + new_module = AwqLoraLinear(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight return new_module From c05feec5ff2b85323013e54c5eaaac10b9a53900 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Feb 2024 01:35:36 +0100 Subject: [PATCH 09/22] fix --- src/peft/tuners/lora/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 5c920be1aa..a01e2d50c7 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -66,7 +66,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): # Megatron ColumnParallelLinear,RowParallelLinear in_features, out_features = base_layer.input_size, base_layer.output_size - elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": + elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "AwqLoraLinear": # Awq layers in_features, out_features = base_layer.in_features, base_layer.out_features else: From 46846997b57b3ebcacf3e95feb7293efc83dfa98 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Feb 2024 01:38:44 +0100 Subject: [PATCH 10/22] change to peft internal testing --- tests/test_gpu_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 2b922ef6eb..b32c4d55af 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1389,7 +1389,7 @@ class 
PeftAwqGPUTests(unittest.TestCase): """ def setUp(self): - self.causal_lm_model_id = "ybelkada/opt-125m-awq" + self.causal_lm_model_id = "peft-internal-testing/opt-125m-awq" self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) def tearDown(self): From fcd51b99d20939eee6e72362c4908a2a086d1393 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Feb 2024 01:45:56 +0100 Subject: [PATCH 11/22] fix --- src/peft/import_utils.py | 22 +++++++++++----------- src/peft/tuners/lora/awq.py | 4 ++-- src/peft/tuners/lora/layer.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index e1bff6116f..4f0e31544e 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -67,15 +67,15 @@ def is_torch_tpu_available(check_device=True): def is_auto_awq_available(): if importlib.util.find_spec("awq") is not None: - # TODO: change it to 0.2.0 - # AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.7") - # version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) - # if AUTOAWQ_MINIMUM_VERSION >= version_autoawq: - # return True - # else: - # raise ImportError( - # f"Found an incompatible version of auto-gptq. Found version {version_autoawq}, " - # f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported" - # ) - return True + AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.7") + version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) + + if AUTOAWQ_MINIMUM_VERSION <= version_autoawq: + return True + else: + raise ImportError( + f"Found an incompatible version of auto-awq. Found version {version_autoawq}, " + f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT." + ) + return False diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index d40473346e..72168b7db2 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -22,7 +22,7 @@ if is_auto_awq_available(): - from awq.modules.linear import WQLinear_GEMM as AWQ_WQLinear_GEMM + from awq.modules.linear import WQLinear_GEMM class AwqLoraLinear(torch.nn.Module, LoraLayer): @@ -90,7 +90,7 @@ def dispatch_awq( else: target_base_layer = target - if is_auto_awq_available() and isinstance(target_base_layer, AWQ_WQLinear_GEMM): + if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM): new_module = AwqLoraLinear(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index a01e2d50c7..5c920be1aa 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -66,7 +66,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): # Megatron ColumnParallelLinear,RowParallelLinear in_features, out_features = base_layer.input_size, base_layer.output_size - elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "AwqLoraLinear": + elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": # Awq layers in_features, out_features = base_layer.in_features, base_layer.out_features else: From 87b677fba0235b79ab0bdd1fbdf5559226e99c3b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Feb 2024 01:49:05 +0100 Subject: [PATCH 12/22] fix --- src/peft/import_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 
4f0e31544e..c1d1edfd59 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -67,7 +67,7 @@ def is_torch_tpu_available(check_device=True): def is_auto_awq_available(): if importlib.util.find_spec("awq") is not None: - + # TODO: to change it to 0.2.0 after the AutoAWQ release AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.7") version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) From 4f22260749d77393a121f2c1f477690dee092ea2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 14 Feb 2024 01:51:47 +0100 Subject: [PATCH 13/22] add multi-GPU tests --- tests/test_gpu_examples.py | 63 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index b32c4d55af..701e5ac440 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1463,3 +1463,66 @@ def test_causal_lm_training_awq(self): # assert loss is not None self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + @pytest.mark.multi_gpu_tests + @require_torch_multi_gpu + def test_causal_lm_training_multi_gpu(self): + r""" + Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set + correctly. + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + torch_dtype=torch.float16, + device_map="auto", + quantization_config=self.quantization_config, + ) + + self.assertEqual(set(model.hf_device_map.values()), set(range(torch.cuda.device_count()))) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) + self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)) + + # assert loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) From 1d52a4530f8b68d1574358d125323c56da273dfc Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:07:31 +0000 Subject: [PATCH 14/22] add to dockerfile --- docker/peft-gpu/Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index 925ade2e5a..a5b33d2658 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -40,6 +40,9 @@ SHELL ["/bin/bash", "-c"] RUN source activate peft && \ python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq +# Add autoawq for quantization testing +RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl + # Install apt libs RUN apt-get update && \ apt-get 
install -y curl git wget && \ From 07c048658c5a1525cf5af8f06800e9341d97043a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:08:26 +0000 Subject: [PATCH 15/22] fix todo --- src/peft/import_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index c1d1edfd59..8d8240a269 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -67,8 +67,7 @@ def is_torch_tpu_available(check_device=True): def is_auto_awq_available(): if importlib.util.find_spec("awq") is not None: - # TODO: to change it to 0.2.0 after the AutoAWQ release - AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.7") + AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) if AUTOAWQ_MINIMUM_VERSION <= version_autoawq: From 2f93c823a0900f7faa4dd774fa94eecc47896879 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:10:27 +0000 Subject: [PATCH 16/22] raise error only at the dispatch level --- src/peft/import_utils.py | 13 +------------ src/peft/tuners/lora/awq.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 8d8240a269..b3ff72bc05 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -66,15 +66,4 @@ def is_torch_tpu_available(check_device=True): def is_auto_awq_available(): - if importlib.util.find_spec("awq") is not None: - AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") - version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) - - if AUTOAWQ_MINIMUM_VERSION <= version_autoawq: - return True - else: - raise ImportError( - f"Found an incompatible version of auto-awq. Found version {version_autoawq}, " - f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT." - ) - return False + return importlib.util.find_spec("awq") is not None diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index 72168b7db2..b3f5bf3978 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import importlib.metadata as importlib_metadata from typing import Any, Optional +import packaging.version import torch from peft.import_utils import is_auto_awq_available @@ -91,6 +92,16 @@ def dispatch_awq( target_base_layer = target if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM): + # Raise the error only at the dispatch level + AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") + version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) + + if AUTOAWQ_MINIMUM_VERSION > version_autoawq: + raise ImportError( + f"Found an incompatible version of auto-awq. Found version {version_autoawq}, " + f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT." 
+ ) + new_module = AwqLoraLinear(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight From ec37ff41848f4c38f2959c2b04e49817e1a8c451 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:17:36 +0000 Subject: [PATCH 17/22] quality --- tests/test_gpu_examples.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 701e5ac440..ec50409d99 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1458,11 +1458,11 @@ def test_causal_lm_training_awq(self): model.cpu().save_pretrained(tmp_dir) - self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) - self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)) + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) # assert loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None @pytest.mark.multi_gpu_tests @require_torch_multi_gpu @@ -1477,7 +1477,6 @@ def test_causal_lm_training_multi_gpu(self): self.causal_lm_model_id, torch_dtype=torch.float16, device_map="auto", - quantization_config=self.quantization_config, ) self.assertEqual(set(model.hf_device_map.values()), set(range(torch.cuda.device_count()))) @@ -1521,8 +1520,8 @@ def test_causal_lm_training_multi_gpu(self): model.cpu().save_pretrained(tmp_dir) - self.assertTrue("adapter_config.json" in os.listdir(tmp_dir)) - self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)) + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) # assert loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None From ec422d1edbbfeae6d42a596c9040ab184df0fe36 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:26:38 +0000 Subject: [PATCH 18/22] fix test --- docker/peft-gpu/Dockerfile | 1 + tests/test_gpu_examples.py | 6 +----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index a5b33d2658..c31654febc 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -42,6 +42,7 @@ RUN source activate peft && \ # Add autoawq for quantization testing RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl +RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4+cu118-cp38-cp38-linux_x86_64.whl # Install apt libs RUN apt-get update && \ diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index ec50409d99..8f7a04d18e 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1417,7 +1417,6 @@ def test_causal_lm_training_awq(self): with tempfile.TemporaryDirectory() as tmp_dir: model = AutoModelForCausalLM.from_pretrained( self.causal_lm_model_id, - torch_dtype=torch.float16, device_map="auto", ) @@ -1447,7 +1446,6 @@ def test_causal_lm_training_awq(self): warmup_steps=2, max_steps=3, learning_rate=2e-4, - fp16=True, logging_steps=1, output_dir=tmp_dir, ), @@ -1475,11 +1473,10 @@ def test_causal_lm_training_multi_gpu(self): with tempfile.TemporaryDirectory() as tmp_dir: model = AutoModelForCausalLM.from_pretrained( self.causal_lm_model_id, - 
torch_dtype=torch.float16, device_map="auto", ) - self.assertEqual(set(model.hf_device_map.values()), set(range(torch.cuda.device_count()))) + assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count())) model = prepare_model_for_kbit_training(model) @@ -1509,7 +1506,6 @@ def test_causal_lm_training_multi_gpu(self): warmup_steps=2, max_steps=3, learning_rate=2e-4, - fp16=True, logging_steps=1, output_dir=tmp_dir, ), From 198f5641567e883b9c8ee736160c9c57707344f3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:28:48 +0000 Subject: [PATCH 19/22] fix dockerfile --- docker/peft-gpu/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index c31654febc..9b2bc4f1b5 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -41,8 +41,8 @@ RUN source activate peft && \ python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq # Add autoawq for quantization testing -RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl -RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4+cu118-cp38-cp38-linux_x86_64.whl +RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0-cp38-cp38-linux_x86_64.whl +RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl # Install apt libs RUN apt-get update && \ From f5123b57cf31e201ece0e9f4de52544b07e2429a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:37:55 +0000 Subject: [PATCH 20/22] fix --- tests/test_gpu_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 8f7a04d18e..9537be20cc 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1405,7 +1405,7 @@ def _check_inference_finite(self, model, batch): training = model.training model.eval() output = model(**batch.to(model.device)) - self.assertTrue(torch.isfinite(output.logits).all()) + assert torch.isfinite(output.logits).all() model.train(training) @pytest.mark.single_gpu_tests From 33d8f11ec6b6350e24f1642f8b004900e4f32765 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 16 Feb 2024 01:41:51 +0000 Subject: [PATCH 21/22] fix --- docker/peft-gpu/Dockerfile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index 9b2bc4f1b5..9a522c1af8 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -41,8 +41,10 @@ RUN source activate peft && \ python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq # Add autoawq for quantization testing -RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0-cp38-cp38-linux_x86_64.whl -RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl +RUN source activate peft && \ + python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0-cp38-cp38-linux_x86_64.whl +RUN source activate peft && \ + python3 -m pip install 
--no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl # Install apt libs RUN apt-get update && \ From 47f006d04f40bebb828153c41ebcd674f3381515 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 19 Feb 2024 00:12:58 +0000 Subject: [PATCH 22/22] update dockerfile and tests --- docker/peft-gpu/Dockerfile | 2 +- tests/test_gpu_examples.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index 9a522c1af8..5e86834d53 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -42,7 +42,7 @@ RUN source activate peft && \ # Add autoawq for quantization testing RUN source activate peft && \ - python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0-cp38-cp38-linux_x86_64.whl + python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.1/autoawq-0.2.1-cp38-cp38-linux_x86_64.whl RUN source activate peft && \ python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 9537be20cc..943f3f54de 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1448,6 +1448,7 @@ def test_causal_lm_training_awq(self): learning_rate=2e-4, logging_steps=1, output_dir=tmp_dir, + fp16=True, ), data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), )
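
With the last patch applied, the end state of the series is: the LoRA wrapper is named AwqLoraLinear, the autoawq version requirement (0.2.0 or newer) is enforced only inside dispatch_awq when an AWQ layer is actually encountered, and the Docker image installs matching AutoAWQ and kernel wheels. A quick sanity check of that end state, sketched under the assumption that such a wheel is installed (the snippet is illustrative and not part of the series):

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model
    from peft.tuners.lora.awq import AwqLoraLinear  # module added by this series

    model = AutoModelForCausalLM.from_pretrained("peft-internal-testing/opt-125m-awq", device_map="auto")
    model = get_peft_model(
        model,
        LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
    )

    # Confirm that the quantized projections were dispatched to the new AWQ-aware LoRA layer.
    awq_lora = [name for name, module in model.named_modules() if isinstance(module, AwqLoraLinear)]
    assert awq_lora, "no AwqLoraLinear layers found; dispatch_awq did not trigger"
    print(f"dispatched {len(awq_lora)} AWQ LoRA layers, e.g. {awq_lora[0]}")

Keeping the version check inside dispatch_awq (patch 16) keeps is_auto_awq_available() a plain availability probe: an old autoawq install no longer raises during the availability check, and the incompatibility error only surfaces when an AWQ layer is actually targeted.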