FEAT: add AWQ support in PEFT (huggingface#1399)
* add AWQ support in PEFT

* fix

* fix

* Update src/peft/tuners/lora/awq.py

* style & fix tests

* forward contrib credits from PR14084

* forward contrib credits from autoawq PR

* change name

* fix

* change to peft internal testing

* fix

* fix

* add multi-GPU tests

* add to dockerfile

* fix todo

* raise error only at the dispatch level

* quality

* fix test

* fix dockerfile

* fix

* fix

* update dockerfile and tests

---------

Co-authored-by: s4rduk4r <s4rduk4r@users.noreply.github.com>
2 people authored and BenjaminBossan committed Mar 14, 2024
1 parent 972ca0a commit ab7c50a
Showing 7 changed files with 280 additions and 5 deletions.
6 changes: 6 additions & 0 deletions docker/peft-gpu/Dockerfile
@@ -40,6 +40,12 @@ SHELL ["/bin/bash", "-c"]
RUN source activate peft && \
    python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq

# Add autoawq for quantization testing
RUN source activate peft && \
    python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.1/autoawq-0.2.1-cp38-cp38-linux_x86_64.whl
RUN source activate peft && \
    python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl

# Install apt libs
RUN apt-get update && \
    apt-get install -y curl git wget && \
4 changes: 4 additions & 0 deletions src/peft/import_utils.py
@@ -63,3 +63,7 @@ def is_torch_tpu_available(check_device=True):
                return False
        return True
    return False


def is_auto_awq_available():
    return importlib.util.find_spec("awq") is not None
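For context, a minimal sketch (not part of the diff) of how such an availability check is typically consumed: the optional dependency is imported only behind the guard, so PEFT modules stay importable when autoawq is absent. This mirrors the pattern used in the new awq.py module below.

    from peft.import_utils import is_auto_awq_available

    if is_auto_awq_available():
        # Only import the quantized linear class when autoawq is actually installed.
        from awq.modules.linear import WQLinear_GEMM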
108 changes: 108 additions & 0 deletions src/peft/tuners/lora/awq.py
@@ -0,0 +1,108 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata as importlib_metadata
from typing import Any, Optional

import packaging.version
import torch

from peft.import_utils import is_auto_awq_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer


if is_auto_awq_available():
    from awq.modules.linear import WQLinear_GEMM


class AwqLoraLinear(torch.nn.Module, LoraLayer):
    def __init__(
        self,
        base_layer,
        adapter_name,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer)

        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
        # for backwards compatibility
        self.quant_linear_module = base_layer

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

    def forward(self, x: torch.Tensor):
        result = self.quant_linear_module(x)

        if self.disable_adapters:
            return result

        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            # Outside of autocast, the LoRA weights may be kept in a different dtype than the
            # output of the quantized module, so cast the input and cast the result back.
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = x.to(lora_A.weight.dtype)

            output = lora_B(lora_A(dropout(x)))
            if requires_conversion:
                output = output.to(expected_dtype)
            output = output * scaling
            result = result + output
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep


def dispatch_awq(
    target: torch.nn.Module,
    adapter_name: str,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM):
        # Raise the error only at the dispatch level
        AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
        version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))

        if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
            raise ImportError(
                f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
                f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
            )

        new_module = AwqLoraLinear(target, adapter_name, **kwargs)
        target.qweight = target_base_layer.qweight

    return new_module
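Taken together, this means LoRA can be applied on top of an AWQ-quantized checkpoint with no extra user-facing API. A minimal usage sketch (not part of the diff; the checkpoint is the one used in the tests below, and the hyperparameters are illustrative):

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Load a checkpoint that was already quantized with autoawq; transformers restores the
    # WQLinear_GEMM layers from the stored quantization config.
    model = AutoModelForCausalLM.from_pretrained(
        "peft-internal-testing/opt-125m-awq",
        device_map="auto",
    )

    config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )

    # get_peft_model walks the model; dispatch_awq matches each WQLinear_GEMM base layer and
    # wraps it in AwqLoraLinear, keeping the quantized weights frozen while training the adapters.
    model = get_peft_model(model, config)
    model.print_trainable_parameters()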
3 changes: 3 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -66,6 +66,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
# Megatron ColumnParallelLinear,RowParallelLinear
in_features, out_features = base_layer.input_size, base_layer.output_size
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
# Awq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")

11 changes: 7 additions & 4 deletions src/peft/tuners/lora/model.py
@@ -38,6 +38,7 @@
)
from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties

from .awq import dispatch_awq
from .config import LoraConfig
from .gptq import dispatch_gptq
from .layer import Conv2d, LoraLayer, dispatch_default
@@ -156,9 +157,11 @@ def _create_and_replace(
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
}

quantization_config = get_quantization_config(self.model, method="gptq")
if quantization_config is not None:
kwargs["gptq_quantization_config"] = quantization_config
quant_methods = ["gptq", "awq"]
for quant_method in quant_methods:
quantization_config = get_quantization_config(self.model, method=quant_method)
if quantization_config is not None:
kwargs[f"{quant_method}_quantization_config"] = quantization_config

# note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
from peft.tuners.adalora import AdaLoraLayer
@@ -244,7 +247,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

dispatchers.append(dispatch_bnb_4bit)

dispatchers.extend([dispatch_gptq, dispatch_megatron, dispatch_default])
dispatchers.extend([dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])

new_module = None
for dispatcher in dispatchers:
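The hunk above is truncated right after the loop header. As a rough sketch of the surrounding code (an assumption, not shown in this diff), the dispatcher chain resolves each targeted module first-match-wins, which is why dispatch_awq must be registered ahead of dispatch_default:

    new_module = None
    for dispatcher in dispatchers:
        # Each dispatch_* function inspects the target's base layer and returns a replacement
        # module or None; the first dispatcher that matches (e.g. on a WQLinear_GEMM) wins.
        new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
        if new_module is not None:
            break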
144 changes: 144 additions & 0 deletions tests/test_gpu_examples.py
@@ -53,6 +53,7 @@
from peft.utils import SAFETENSORS_WEIGHTS_NAME

from .testing_utils import (
    require_auto_awq,
    require_auto_gptq,
    require_bitsandbytes,
    require_optimum,
@@ -1380,3 +1381,146 @@ def test_model_loaded_in_float16_working(self):
            data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )
        trainer.train()


@require_torch_gpu
@require_auto_awq
class PeftAwqGPUTests(unittest.TestCase):
    r"""
    Awq + peft tests
    """

    def setUp(self):
        self.causal_lm_model_id = "peft-internal-testing/opt-125m-awq"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        gc.collect()
        torch.cuda.empty_cache()

    def _check_inference_finite(self, model, batch):
        # try inference without Trainer class
        training = model.training
        model.eval()
        output = model(**batch.to(model.device))
        assert torch.isfinite(output.logits).all()
        model.train(training)

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_awq(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map="auto",
            )

            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)

            data = load_dataset("ybelkada/english_quotes_copy")
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

            # TODO: deal correctly with this case in transformers
            model._is_quantized_training_enabled = True

            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                    fp16=True,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)

            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.multi_gpu_tests
    @require_torch_multi_gpu
    def test_causal_lm_training_multi_gpu(self):
        r"""
        Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
        correctly.
        """

        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map="auto",
            )

            assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count()))

            model = prepare_model_for_kbit_training(model)

            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)

            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )

            model = get_peft_model(model, config)

            data = load_dataset("Abirate/english_quotes")
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)

            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)

            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
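After a run like the tests above, the saved adapter can be attached back onto a freshly loaded AWQ base model for inference. A short sketch (not part of the diff; standard PEFT API, reusing the same test checkpoint, with a placeholder path for wherever save_pretrained wrote the adapter):

    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained(
        "peft-internal-testing/opt-125m-awq",
        device_map="auto",
    )
    # Attach the LoRA weights that were saved with save_pretrained(tmp_dir) above.
    model = PeftModel.from_pretrained(base, "path/to/saved_adapter")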
9 changes: 8 additions & 1 deletion tests/testing_utils.py
@@ -18,7 +18,7 @@
import pytest
import torch

from peft.import_utils import is_auto_gptq_available, is_optimum_available
from peft.import_utils import is_auto_awq_available, is_auto_gptq_available, is_optimum_available


def require_torch_gpu(test_case):
@@ -61,6 +61,13 @@ def require_auto_gptq(test_case):
    return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


def require_auto_awq(test_case):
    """
    Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed.
    """
    return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case)


def require_optimum(test_case):
    """
    Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed.
