
Add AutoGPTQ's cpu kernel. #245

Merged: 13 commits into AutoGPTQ:main on Aug 25, 2023

Conversation

@qwopqwop200 (Collaborator) commented Aug 10, 2023:

Currently only 2 and 4 bits are supported; desc_act is not supported.
Enabling use_qigen overrides all CUDA-related settings.

```python
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit-128g"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

# use_qigen=True selects the QIGen CPU kernel instead of the CUDA kernels
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, use_qigen=True)

print(tokenizer.decode(model.generate(**tokenizer("test is", return_tensors="pt"), max_new_tokens=5)[0]))
```
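
The snippet above assumes that `opt-125m-4bit-128g` already contains a GPTQ-quantized checkpoint. As a minimal sketch, such a directory is typically produced with the standard AutoGPTQ API roughly as follows (the single calibration sentence is only a placeholder; real calibration data should be used):

```python
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = "opt-125m-4bit-128g"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

# 4-bit, group size 128, no desc_act (matching the current limitations of this kernel)
quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)

model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# placeholder calibration example
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]
model.quantize(examples)

model.save_quantized(quantized_model_dir)
```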
| opt-125m | Bits | group-size | Wikitext2 |
| --- | --- | --- | --- |
| cuda | 4 | 128 | 29.8402 |
| QIgen | 4 | 128 | 29.8416 |

@qwopqwop200 qwopqwop200 marked this pull request as draft August 11, 2023 04:32
@PanQiWei (Collaborator) left a comment:


Hi, I'm just leaving some reviews to make sure we can get this PR into good shape, so that it can be easily extended and maintained in the future.

auto_gptq/modeling/_base.py (review thread resolved)
auto_gptq/nn_modules/qlinear/qlinear_qigen.py (review thread resolved)
```python
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable
checkpoint=checkpoint,
```

I would suggest calling make_quant_cpu first and then loading the checkpoint. Of course this requires a further modification in make_quant_cpu(); I will describe it there.

```python
qweights=checkpoint[name1 + '.qweight'].contiguous(),
zeros=checkpoint[name1 + '.qzeros'],
scales=checkpoint[name1 + '.scales'].float(),
bias=checkpoint[name1 + '.bias'].float() if name1 + '.bias' in checkpoint else None)
```

TBH, I'm not fond of the way quantization is wired up here: the checkpoint is loaded directly while initializing QuantLinear, which breaks the API format so that it differs from the other QuantLinear classes and may cause maintenance difficulties in the future.

My suggestion is to only replace modules in the make_quant functions, and to load the state dict after this function is called in from_quantized().
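
A minimal sketch of that suggested flow, assuming a make_quant_cpu() helper analogous to the existing make_quant(); the argument names below are placeholders, not the exact signature in this PR:

```python
# inside from_quantized(), mirroring how the other backends are wired up:
# 1) replace the target nn.Linear modules with empty QIGen QuantLinear shells
make_quant_cpu(
    model,
    names=layers_to_quantize,          # placeholder: module names selected for quantization
    bits=quantize_config.bits,
    group_size=quantize_config.group_size,
)

# 2) only afterwards load the quantized weights, the same way every other QuantLinear is populated
model.load_state_dict(checkpoint, strict=False)
```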

```python
class QuantLinear(nn.Module):
    QUANT_TYPE = "qigen"

    def __init__(self, bits=4, group_size=-1, N=0, M=0, qweights=None, zeros=None, scales=None, bias=None, hint=1, p=8, l1=2**18):
```

At least these five arguments' names and order should be the same as in the other QuantLinear classes: bits, group_size, infeatures, outfeatures, bias.
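
For illustration only, a sketch of what the aligned signature could look like; the QIGen-specific knobs (hint, p, l1) are simply kept after the shared arguments, and nothing here is the final API of this PR:

```python
import torch.nn as nn

class QuantLinear(nn.Module):
    QUANT_TYPE = "qigen"

    # shared argument names and order first, matching the other QuantLinear backends
    def __init__(self, bits, group_size, infeatures, outfeatures, bias,
                 hint=1, p=8, l1=2**18):
        super().__init__()
        if bits not in (2, 4):
            raise NotImplementedError("QIGen kernel currently supports 2 and 4 bits only.")
        self.bits = bits
        self.group_size = group_size if group_size != -1 else infeatures
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        # qweight/qzeros/scales would be registered as buffers and filled later via load_state_dict
```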

@qwopqwop200 qwopqwop200 marked this pull request as ready for review August 17, 2023 06:23
@qwopqwop200 (Author):

My code is complete. Please review again.

@fxmarty (Collaborator) commented Aug 18, 2023:

@qwopqwop200 Did not look at the code but given that CPUs usually do not support fp16 compute, is the compute rather done in fp32? With dequantization e.g. int4 -> fp32?
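
For readers unfamiliar with the term, here is a toy illustration of what group-wise int4 -> fp32 dequantization means; this is not the QIGen kernel, just the concept being asked about:

```python
import torch

def dequantize_int4_group(qweight_int4: torch.Tensor, scale: float, zero: int) -> torch.Tensor:
    # qweight_int4 holds values in [0, 15]; each group shares one fp32 scale and one zero point.
    # The subsequent matmul then runs entirely in fp32 on the CPU.
    return (qweight_int4.to(torch.float32) - zero) * scale

q = torch.tensor([0, 3, 7, 15], dtype=torch.uint8)
print(dequantize_int4_group(q, scale=0.05, zero=8))  # tensor([-0.4000, -0.2500, -0.0500,  0.3500])
```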

@qwopqwop200 (Author):

Probably so. But I didn't design this kernel, so I'm not sure how it works.

@fxmarty (Collaborator) commented Aug 21, 2023:

Where does it come from?

@qwopqwop200 (Author):

> Where does it come from?

https://github.com/IST-DASLab/QIGen/tree/master

@qwopqwop200 qwopqwop200 merged commit 6a9d80e into AutoGPTQ:main Aug 25, 2023
@qwopqwop200 (Author):

Oh sorry, I merged this by mistake. I've checked that it's working now, so there shouldn't be any problems.
