From 162c294ddadbcc8c8609c887f19deaff6608a2f9 Mon Sep 17 00:00:00 2001 From: Casper Date: Fri, 26 Jan 2024 16:53:43 +0100 Subject: [PATCH 1/3] Initial --- awq/modules/linear/gemm.py | 117 +++++++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 24 deletions(-) diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py index 6128adad..e270653b 100644 --- a/awq/modules/linear/gemm.py +++ b/awq/modules/linear/gemm.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from torch.autograd import Function from awq.utils.utils import get_best_device from awq.utils.packing_utils import dequantize_gemm @@ -9,9 +10,81 @@ except: AWQ_INSTALLED = False +# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev +class WQLinearMMFunction(Function): + @staticmethod + # ctx is the first argument to forward + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0 + ): + # The forward pass can use ctx. + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features, ) + x = x.to(torch.float16) + + if AWQ_INSTALLED: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_ext.dequantize_weights_cuda( + qweight, + scales, + qzeros, + 0, + 0, + 0, + False + ) + out = torch.matmul(x, out) + else: + out = awq_ext.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), + qweight, + scales, + qzeros, + 8 + ) + else: + out = dequantize_gemm( + qweight, + qzeros, + scales, + w_bit, + group_size + ) + out = torch.matmul(x, out) + + out = out + bias if bias is not None else out + + return out.reshape(out_shape) + + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + out_features = ctx.out_features + grad_input = grad_weight = grad_zeros = grad_scales = grad_bias = grad_out_features = None + weight = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 1, 0, 0, False) + + if ctx.needs_input_grad[0]: + grad_input = grad_output[0].mm(weight.t()).unsqueeze(0) + if bias is not None and ctx.needs_input_grad[4]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_zeros, grad_scales, grad_bias, grad_out_features + class WQLinear_GEMM(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): + def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): super().__init__() if w_bit not in [4]: @@ -21,6 +94,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): self.out_features = out_features self.w_bit = w_bit self.group_size = group_size if group_size != -1 else in_features + self.training = training # quick sanity check (make sure aligment) assert self.in_features % self.group_size == 0 @@ -144,7 +218,6 @@ def from_linear( return awq_linear - @torch.no_grad() def forward(self, x): out_shape = x.shape[:-1] + (self.out_features,) @@ -152,33 +225,29 @@ def forward(self, x): if input_dtype != torch.float16: x = x.half() - if AWQ_INSTALLED: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_ext.dequantize_weights_cuda( - self.qweight, - self.scales, - self.qzeros, - 0, - 0, - 0, - False - ) - out = torch.matmul(x, out) - else: - out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 - ) - else: - out = dequantize_gemm( + if 
self.training: + out = WQLinearMMFunction.apply( + x, self.qweight, self.qzeros, self.scales, self.w_bit, - self.group_size + self.group_size, + self.bias, + self.out_features, ) - out = torch.matmul(x, out) + else: + with torch.no_grad(): + out = WQLinearMMFunction.apply( + x, + self.qweight, + self.qzeros, + self.scales, + self.w_bit, + self.group_size, + self.bias, + self.out_features, + ) if input_dtype != torch.float16: out = out.to(dtype=input_dtype) From 8d3dd561c36b6846a7f65041288147c87ae34c16 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 26 Jan 2024 18:46:20 +0000 Subject: [PATCH 2/3] Fix PEFT training --- awq/modules/linear/gemm.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py index e270653b..5726c1c4 100644 --- a/awq/modules/linear/gemm.py +++ b/awq/modules/linear/gemm.py @@ -65,22 +65,35 @@ def forward( out = torch.matmul(x, out) out = out + bias if bias is not None else out + out = out.reshape(out_shape) - return out.reshape(out_shape) + # always want 3D tensor if tensor is 2D + if len(out.shape) == 2: + out = out.unsqueeze(0) + + return out @staticmethod def backward(ctx, grad_output): input, qweight, qzeros, scales, bias = ctx.saved_tensors - out_features = ctx.out_features - grad_input = grad_weight = grad_zeros = grad_scales = grad_bias = grad_out_features = None - weight = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 1, 0, 0, False) + + weights = awq_ext.dequantize_weights_cuda( + qweight, + scales, + qzeros, + 1, + 0, + 0, + False + ) if ctx.needs_input_grad[0]: - grad_input = grad_output[0].mm(weight.t()).unsqueeze(0) - if bias is not None and ctx.needs_input_grad[4]: - grad_bias = grad_output.sum(0) + # 2D matrix multiplication, unsqueeze to 3D + grad_input = grad_output.squeeze(0).mm( + weights.transpose(0, 1) + ).unsqueeze(0) - return grad_input, grad_weight, grad_zeros, grad_scales, grad_bias, grad_out_features + return grad_input, None, None, None, None, None, None, None class WQLinear_GEMM(nn.Module): From 51a1a8a41d6c8b9330d278602ea083784266e5b3 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 26 Jan 2024 18:46:37 +0000 Subject: [PATCH 3/3] Re-add training example --- examples/awq_train.py | 76 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 examples/awq_train.py diff --git a/examples/awq_train.py b/examples/awq_train.py new file mode 100644 index 00000000..5e8fd0f5 --- /dev/null +++ b/examples/awq_train.py @@ -0,0 +1,76 @@ +import datasets +from awq import AutoAWQForCausalLM +from transformers import ( + AutoTokenizer, + TrainingArguments, + Trainer, + DataCollatorForLanguageModeling +) +from peft import get_peft_model, LoraConfig, TaskType + +def prepare_split(tokenizer): + data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train") + prompt_template = "[INST] {system} {prompt} [/INST] {output}" + + def format_prompt(x): + return prompt_template.format( + system="", + prompt=x["instruction"], + output=x["output"] + ) + + data = data.map( + lambda x: {"text": format_prompt(x)}, + ).select_columns(["text"]) + data = data.map(lambda x: tokenizer(x["text"]), batched=True) + + return data + +model_path = "ybelkada/opt-125m-awq" + +# Load model +model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False) +tokenizer = AutoTokenizer.from_pretrained(model_path) +tokenizer.pad_token = tokenizer.eos_token + +# Prepare data +data_train = prepare_split(tokenizer) + 
+# Config Lora +lora_config = LoraConfig( + r=4, + lora_alpha=8, + lora_dropout=0.5, + bias="none", + task_type=TaskType.CAUSAL_LM, + inference_mode=False +) + +model = get_peft_model(model.model, lora_config) + +model.print_trainable_parameters() + +training_arguments = TrainingArguments( + output_dir="./output", + per_device_train_batch_size=1, + optim="adamw_torch", + num_train_epochs=1, + learning_rate=1e-4, + # fp16=True, + evaluation_strategy="no", + save_strategy="epoch", + save_steps=100, + logging_steps=50, + eval_steps=None, + load_best_model_at_end=False +) + +trainer = Trainer( + model=model, + train_dataset=data_train, + args=training_arguments, + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), +) + +trainer.train() +trainer.save_model("output") \ No newline at end of file
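
Follow-up note (not part of the patches above): the training example ends with trainer.save_model("output"), so a natural next step is to load that LoRA adapter back onto the quantized base model and run a quick generation as a sanity check. The sketch below is an assumption-laden illustration, not code from this series: it reuses the same base checkpoint (ybelkada/opt-125m-awq) and adapter directory ("output") as examples/awq_train.py, relies on PEFT's PeftModel.from_pretrained, and uses a made-up prompt in the example's own prompt template. As in the training script, fuse_layers=False keeps the plain WQLinear_GEMM modules in place, which is the module the new autograd path is attached to.

import torch
from awq import AutoAWQForCausalLM
from peft import PeftModel
from transformers import AutoTokenizer

# Sketch only: same base checkpoint and adapter path as examples/awq_train.py.
model_path = "ybelkada/opt-125m-awq"
adapter_path = "output"

# Load the quantized base model without layer fusion, as in the training script,
# so the WQLinear_GEMM modules (and the LoRA layers wrapped around them) are used.
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Attach the trained LoRA adapter to the underlying HF model (model.model),
# mirroring how get_peft_model was applied to model.model during training.
model = PeftModel.from_pretrained(model.model, adapter_path)
model.eval()

# Illustrative prompt in the same template as the training data (empty system field).
prompt = "[INST]  Give three examples of fruit. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(next(model.parameters()).device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

One caveat: merging the adapter into the base weights (merge_and_unload) is not expected to work here, since WQLinear_GEMM stores packed quantized weights rather than a dense float weight matrix; keeping the adapter attached at inference time, as above, sidesteps that.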