PEFT compatible GEMM #324

Merged · 5 commits · Feb 3, 2024
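For context: the diff below wraps AutoAWQ's quantized GEMM in a torch.autograd.Function so that gradients can flow back to the activations while the packed 4-bit weights stay frozen, which is what PEFT/LoRA fine-tuning needs. A stripped-down sketch of that pattern (illustrative only; a plain weight tensor stands in for the weight that gemm.py dequantizes on the fly):

import torch
from torch.autograd import Function

class DequantMatMulSketch(Function):
    # Sketch only: w stands in for the weight that WQLinearMMFunction
    # dequantizes from qweight/qzeros/scales at run time.
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(w)
        return x @ w

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # Gradient w.r.t. the activations only; the frozen weight gets None.
        return grad_output @ w.t(), None

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 32)
DequantMatMulSketch.apply(x, w).sum().backward()
print(x.grad.shape)  # torch.Size([8, 16])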
132 changes: 105 additions & 27 deletions awq/modules/linear/gemm.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Function
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm

@@ -10,9 +11,94 @@
except:
AWQ_INSTALLED = False

# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(
ctx,
x,
qweight,
qzeros,
scales,
w_bit=4,
group_size=128,
bias=None,
out_features=0
):
# The forward pass can use ctx.
ctx.save_for_backward(x, qweight, qzeros, scales, bias)
ctx.out_features = out_features

out_shape = x.shape[:-1] + (out_features, )
x = x.to(torch.float16)

if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
0,
0,
0,
False
)
out = torch.matmul(x, out)
else:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]),
qweight,
scales,
qzeros,
8
)
else:
out = dequantize_gemm(
qweight,
qzeros,
scales,
w_bit,
group_size
)
out = torch.matmul(x, out)

out = out + bias if bias is not None else out
out = out.reshape(out_shape)

# always return a 3D tensor; unsqueeze if the result is 2D
if len(out.shape) == 2:
out = out.unsqueeze(0)

return out

@staticmethod
def backward(ctx, grad_output):
input, qweight, qzeros, scales, bias = ctx.saved_tensors

weights = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
1,
0,
0,
False
)

if ctx.needs_input_grad[0]:
# drop the batch dim for a 2D matmul, then unsqueeze back to 3D
grad_input = grad_output.squeeze(0).mm(
weights.transpose(0, 1)
).unsqueeze(0)

return grad_input, None, None, None, None, None, None, None


class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
super().__init__()

if w_bit not in [4]:
@@ -22,6 +108,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.training = training

# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
@@ -145,45 +232,36 @@ def from_linear(

return awq_linear

@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)

input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()

if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda(
self.qweight,
self.scales,
self.qzeros,
0,
0,
0,
False,
)
out = torch.matmul(x, out)
else:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
8,
)
else:
out = dequantize_gemm(
if self.training:
out = WQLinearMMFunction.apply(
x,
self.qweight,
self.qzeros,
self.scales,
self.w_bit,
self.group_size,
self.bias,
self.out_features,
)
out = torch.matmul(x, out)
else:
with torch.no_grad():
out = WQLinearMMFunction.apply(
x,
self.qweight,
self.qzeros,
self.scales,
self.w_bit,
self.group_size,
self.bias,
self.out_features,
)

if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
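A hypothetical smoke test (not part of this PR) for the new training path. It assumes the awq_ext CUDA kernels are installed and a CUDA device is available, and it builds the layer directly with its default zero-initialized buffers, so the outputs are meaningless; the point is only that gradients reach the activations while the packed weights stay frozen:

import torch
from awq.modules.linear.gemm import WQLinear_GEMM

layer = WQLinear_GEMM(
    w_bit=4, group_size=128, in_features=4096, out_features=4096,
    bias=False, dev="cuda", training=True,
)
x = torch.randn(1, 8, 4096, device="cuda", dtype=torch.float16, requires_grad=True)

layer(x).sum().backward()

print(x.grad.shape)                 # gradients flow back to the input activations
print(layer.qweight.requires_grad)  # False: the packed 4-bit weights stay frozen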
76 changes: 76 additions & 0 deletions examples/awq_train.py
@@ -0,0 +1,76 @@
import datasets
from awq import AutoAWQForCausalLM
from transformers import (
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType

def prepare_split(tokenizer):
data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"

def format_prompt(x):
return prompt_template.format(
system="",
prompt=x["instruction"],
output=x["output"]
)

data = data.map(
lambda x: {"text": format_prompt(x)},
).select_columns(["text"])
data = data.map(lambda x: tokenizer(x["text"]), batched=True)

return data

model_path = "ybelkada/opt-125m-awq"

# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# Prepare data
data_train = prepare_split(tokenizer)

# Configure LoRA
lora_config = LoraConfig(
r=4,
lora_alpha=8,
lora_dropout=0.5,
bias="none",
task_type=TaskType.CAUSAL_LM,
inference_mode=False
)

model = get_peft_model(model.model, lora_config)

model.print_trainable_parameters()

training_arguments = TrainingArguments(
output_dir="./output",
per_device_train_batch_size=1,
optim="adamw_torch",
num_train_epochs=1,
learning_rate=1e-4,
# fp16=True,
evaluation_strategy="no",
save_strategy="epoch",
save_steps=100,
logging_steps=50,
eval_steps=None,
load_best_model_at_end=False
)

trainer = Trainer(
model=model,
train_dataset=data_train,
args=training_arguments,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model("output")
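A possible follow-up (not part of the PR), assuming trainer.save_model("output") wrote the LoRA adapter files to ./output: reload the adapter on top of the quantized base model for inference via PeftModel.from_pretrained.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from peft import PeftModel

model_path = "ybelkada/opt-125m-awq"

# Reload the quantized base model, then attach the trained LoRA adapter.
base = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = PeftModel.from_pretrained(base.model, "output")
model.eval()

# Prompt in the same format used during training (empty system prompt).
prompt = "<s>[INST]  Explain AWQ quantization in one sentence. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))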