FEAT: add awq support in PEFT #1399

Merged: 24 commits, Feb 19, 2024
Changes from 3 commits
16 changes: 16 additions & 0 deletions src/peft/import_utils.py
@@ -64,3 +64,19 @@ def is_torch_tpu_available(check_device=True):
return False
return True
return False


def is_auto_awq_available():
if importlib.util.find_spec("awq") is not None:
# TODO: change it to 0.2.0
Review comment (Member): This is not quite clear to me: change it when?

Reply (Contributor, author): Sorry, I'll remove it.

# AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.1.7")
# version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
# if AUTOAWQ_MINIMUM_VERSION >= version_autoawq:
# return True
# else:
# raise ImportError(
# f"Found an incompatible version of auto-gptq. Found version {version_autoawq}, "
# f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported"
# )

return True
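
For context on the reviewer's question: the commented-out draft gates availability on a minimum autoawq version. A version-gated variant might look like the sketch below; this is not part of the diff, the 0.2.0 bound is only the hypothetical value mentioned in the TODO, and the sketch also fixes the inverted comparison and the stray "auto-gptq" wording from the draft.

```python
import importlib.metadata
import importlib.util

import packaging.version

# Hypothetical minimum version taken from the TODO above; not enforced by this PR.
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")


def is_auto_awq_available():
    # The importable module is "awq", but the distribution on PyPI is "autoawq".
    if importlib.util.find_spec("awq") is None:
        return False
    version_autoawq = packaging.version.parse(importlib.metadata.version("autoawq"))
    if version_autoawq >= AUTOAWQ_MINIMUM_VERSION:
        return True
    raise ImportError(
        f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
        f"but only versions {AUTOAWQ_MINIMUM_VERSION} and above are supported."
    )
```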
98 changes: 98 additions & 0 deletions src/peft/tuners/lora/awq.py
@@ -0,0 +1,98 @@
# coding=utf-8
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch

from peft.import_utils import is_auto_awq_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_auto_awq_available():
from awq.modules.linear import WQLinear_GEMM as AWQ_WQLinear_GEMM


class WQLinear_GEMM(torch.nn.Module, LoraLayer):
def __init__(
self,
base_layer,
adapter_name,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
):
super().__init__()
LoraLayer.__init__(self, base_layer)

# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def forward(self, x: torch.Tensor):
result = self.quant_linear_module(x)

if self.disable_adapters:
return result

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))
if requires_conversion:
output = output.to(expected_dtype)
output = output * scaling
result += output
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "lora." + rep


def dispatch_awq(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None

if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

if isinstance(target_base_layer, AWQ_WQLinear_GEMM):
new_module = WQLinear_GEMM(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

return new_module
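
As a usage illustration (a minimal sketch, not part of the diff; the checkpoint ID is the one used in the PR's GPU tests below), applying a LoRA config to an AWQ-quantized model routes each targeted `WQLinear_GEMM` base layer through `dispatch_awq` and wraps it in the LoRA layer defined above:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# AWQ-quantized checkpoint also used in the PR's GPU tests
model = AutoModelForCausalLM.from_pretrained(
    "ybelkada/opt-125m-awq",
    torch_dtype=torch.float16,
    device_map="auto",
)

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# q_proj/v_proj are WQLinear_GEMM modules here, so dispatch_awq picks them up
model = get_peft_model(model, config)
model.print_trainable_parameters()
```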
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -67,6 +67,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
# Megatron ColumnParallelLinear,RowParallelLinear
in_features, out_features = base_layer.input_size, base_layer.output_size
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
# Awq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")

11 changes: 7 additions & 4 deletions src/peft/tuners/lora/model.py
@@ -38,6 +38,7 @@
get_quantization_config,
)

from .awq import dispatch_awq
from .config import LoraConfig
from .gptq import dispatch_gptq
from .layer import Conv2d, LoraLayer, dispatch_default
@@ -156,9 +157,11 @@ def _create_and_replace(
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
}

quantization_config = get_quantization_config(self.model, method="gptq")
if quantization_config is not None:
kwargs["gptq_quantization_config"] = quantization_config
quant_methods = ["gptq", "awq"]
for quant_method in quant_methods:
quantization_config = get_quantization_config(self.model, method=quant_method)
if quantization_config is not None:
kwargs[f"{quant_method}_quantization_config"] = quantization_config

# note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
from peft.tuners.adalora import AdaLoraLayer
@@ -244,7 +247,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

dispatchers.append(dispatch_bnb_4bit)

dispatchers.extend([dispatch_gptq, dispatch_megatron, dispatch_default])
dispatchers.extend([dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])

new_module = None
for dispatcher in dispatchers:
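
For readers unfamiliar with the dispatcher mechanism: `_create_new_module` walks the list in order, and the first dispatcher that returns a module wins, so `dispatch_awq` only fires when the target's base layer is a `WQLinear_GEMM`. A simplified paraphrase of that loop (not a verbatim excerpt from the file):

```python
def _create_new_module_sketch(lora_config, adapter_name, target, dispatchers, **kwargs):
    # Each dispatch_* helper returns None when the target is not its kind of layer,
    # so dispatch_awq only matches AWQ WQLinear_GEMM base layers.
    new_module = None
    for dispatcher in dispatchers:
        new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
        if new_module is not None:  # first match wins
            break
    if new_module is None:
        raise ValueError(f"Target module {target} is not supported.")
    return new_module
```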
86 changes: 86 additions & 0 deletions tests/test_gpu_examples.py
@@ -54,6 +54,7 @@
from peft.utils import SAFETENSORS_WEIGHTS_NAME

from .testing_utils import (
require_auto_awq,
require_auto_gptq,
require_bitsandbytes,
require_optimum,
@@ -1378,3 +1379,88 @@ def test_model_loaded_in_float16_working(self):
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
trainer.train()


@require_torch_gpu
@require_auto_awq
class PeftAwqGPUTests(unittest.TestCase):
r"""
Awq + peft tests
"""

def setUp(self):

self.causal_lm_model_id = "ybelkada/opt-125m-awq"
self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

def tearDown(self):
r"""
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
torch.cuda.empty_cache()

def _check_inference_finite(self, model, batch):
# try inference without Trainer class
training = model.training
model.eval()
output = model(**batch.to(model.device))
self.assertTrue(torch.isfinite(output.logits).all())
model.train(training)

@pytest.mark.single_gpu_tests
def test_causal_lm_training_awq(self):
r"""
Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
correctly.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
torch_dtype=torch.float16,
device_map="auto",
)

model = prepare_model_for_kbit_training(model)
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

# TODO: deal correctly with this case in transformers
model._is_quantized_training_enabled = True

trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()

model.cpu().save_pretrained(tmp_dir)

self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))

# assert loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
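
Following the test above, the saved adapter can be loaded back onto the AWQ base model for inference. A minimal sketch using standard PEFT APIs (the adapter path is illustrative, standing in for the test's tmp_dir):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "ybelkada/opt-125m-awq", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("ybelkada/opt-125m-awq")

adapter_dir = "path/to/saved_adapter"  # e.g. the tmp_dir the trainer saved to above
model = PeftModel.from_pretrained(base, adapter_dir)
model.eval()

inputs = tokenizer("Two things are infinite:", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```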
9 changes: 8 additions & 1 deletion tests/testing_utils.py
@@ -19,7 +19,7 @@
import pytest
import torch

from peft.import_utils import is_auto_gptq_available, is_optimum_available
from peft.import_utils import is_auto_awq_available, is_auto_gptq_available, is_optimum_available


def require_torch_gpu(test_case):
@@ -62,6 +62,13 @@ def require_auto_gptq(test_case):
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


def require_auto_awq(test_case):
"""
Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed.
"""
return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case)


def require_optimum(test_case):
"""
Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed.