Add SUPPORTS_AUTO_PADDING property to QuantLinear #799

Merged · 9 commits · Dec 7, 2024
14 changes: 11 additions & 3 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -13,6 +13,7 @@ class BaseQuantLinear(nn.Module):
SUPPORTS_SYM: List[bool] = None
SUPPORTS_SHARDS: bool = None
SUPPORTS_TRAINING: bool = None
SUPPORTS_AUTO_PADDING: bool = None
SUPPORTS_IN_FEATURES_DIVISIBLE_BY: List[int] = None
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY: List[int] = None

@@ -30,9 +31,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat

@classmethod
# custom quant linear classes can override this and add custom checks
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
validate, err = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
validate, err = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym,
infeatures=infeatures, outfeatures=outfeatures, dynamic=dynamic,
device=device, trainable=trainable)
return validate, err

@classmethod
@@ -138,12 +142,16 @@ def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynami
if not validate:
err = f"{cls}: `infeatures` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)

validate = infeatures % group_size == 0 or cls.SUPPORTS_AUTO_PADDING
if not validate:
err = f"{cls}: `infeatures` must be divisible by `group_size: {group_size}`."
return False, NotImplementedError(err)
if outfeatures is not None:
validate = all(outfeatures % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `outfeatures` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)

return True, None

@classmethod
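Taken together, these hunks thread `infeatures`/`outfeatures` from `validate` into `_validate`, where `SUPPORTS_AUTO_PADDING` relaxes only the `group_size` divisibility check. A minimal standalone sketch of that gating logic (the class and method names here are illustrative, not the library's):

from typing import List, Optional, Tuple

class SketchQuantLinear:
    # per-kernel capability flags, mirroring BaseQuantLinear above
    SUPPORTS_AUTO_PADDING: bool = False
    SUPPORTS_IN_FEATURES_DIVISIBLE_BY: List[int] = [32]
    SUPPORTS_OUT_FEATURES_DIVISIBLE_BY: List[int] = [32]

    @classmethod
    def check_shapes(cls, infeatures: int, outfeatures: int,
                     group_size: int) -> Tuple[bool, Optional[Exception]]:
        if not all(infeatures % d == 0 for d in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY):
            return False, NotImplementedError(
                f"`infeatures` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}.")
        # only auto-padding kernels may skip the group_size divisibility check
        if infeatures % group_size != 0 and not cls.SUPPORTS_AUTO_PADDING:
            return False, NotImplementedError(
                f"`infeatures` must be divisible by `group_size: {group_size}`.")
        if not all(outfeatures % d == 0 for d in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY):
            return False, NotImplementedError(
                f"`outfeatures` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}.")
        return True, None

# infeatures=4000 is divisible by 32 but not by group_size=128, so this
# fails unless SUPPORTS_AUTO_PADDING is flipped to True:
ok, err = SketchQuantLinear.check_shapes(infeatures=4000, outfeatures=4096, group_size=128)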
4 changes: 3 additions & 1 deletion gptqmodel/nn_modules/qlinear/bitblas.py
@@ -81,6 +81,7 @@ class BitBLASQuantLinear(BaseQuantLinear):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = False
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16]

@@ -135,7 +136,8 @@ def __init__(
self.reset_parameters()

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
if not BITBLAS_AVAILABLE:
return False, ValueError(BITBLAS_INSTALL_HINT)
1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/dynamic_cuda.py
@@ -24,6 +24,7 @@ class DynamicCudaQuantLinear(TorchQuantLinear):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = False # TODO fix this
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [64]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]

1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/exllama.py
@@ -49,6 +49,7 @@ class ExllamaQuantLinear(BaseQuantLinear):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = False
SUPPORTS_AUTO_PADDING = True
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]

1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -110,6 +110,7 @@ class ExllamaV2QuantLinear(BaseQuantLinear):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = False
SUPPORTS_AUTO_PADDING = True
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]

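Exllama v1 and v2 are the only kernels this PR flags with `SUPPORTS_AUTO_PADDING = True`: rather than rejecting an incompatible shape, they can zero-pad the weight up to the next supported multiple. A sketch of the rounding rule involved, assuming simple round-up padding (the helper name is hypothetical, not from the library):

import math

def pad_to_multiple(n: int, multiple: int) -> int:
    # round n up to the nearest multiple, e.g. pad_to_multiple(4000, 128) == 4096
    return multiple * math.ceil(n / multiple)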
4 changes: 3 additions & 1 deletion gptqmodel/nn_modules/qlinear/ipex.py
@@ -56,6 +56,7 @@ class IPEXQuantLinear(BaseQuantLinear):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = True
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]

@@ -129,7 +130,8 @@ def __init__(
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
if not IPEX_AVAILABLE:
return False, IPEX_ERROR_LOG
1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/marlin.py
@@ -145,6 +145,7 @@ class MarlinQuantLinear(BaseQuantLinear):
SUPPORTS_SYM = [True]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = False
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]

1 change: 1 addition & 0 deletions gptqmodel/nn_modules/qlinear/torch.py
@@ -20,6 +20,7 @@ class TorchQuantLinear(BaseQuantLinear):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = True
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]

4 changes: 3 additions & 1 deletion gptqmodel/nn_modules/qlinear/tritonv2.py
@@ -36,6 +36,7 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = True
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]

@@ -93,7 +94,8 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
self.bias = None

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
if not TRITON_AVAILABLE:
return False, ValueError(TRITON_INSTALL_HINT)
13 changes: 11 additions & 2 deletions gptqmodel/utils/importer.py
@@ -32,8 +32,8 @@
})

format_dict = {
FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH],
FORMAT.GPTQ_V2: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH],
FORMAT.GPTQ: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH],
FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH],
FORMAT.MARLIN: [BACKEND.MARLIN],
FORMAT.BITBLAS: [BACKEND.BITBLAS],
FORMAT.IPEX: [BACKEND.IPEX],
@@ -74,6 +74,7 @@ def hf_select_quant_linear(
device=device,
format=FORMAT.GPTQ,
pack=True,
allow_marlin=False, # TODO: remove this after marlin padding is fixed
dynamic=None,
)

@@ -88,6 +89,7 @@ def select_quant_linear(
backend: BACKEND = BACKEND.AUTO,
format: FORMAT = FORMAT.GPTQ,
pack: bool = False,
allow_marlin: bool = True, # TODO: remove this after marlin padding is fixed
dynamic=None,
) -> Type[BaseQuantLinear]:
if not torch.cuda.is_available():
@@ -103,6 +105,13 @@
trainable = backend == BACKEND.AUTO_TRAINABLE

allow_backends = format_dict[format]

# TODO: fix marlin padding
# Marlin does not support padding in_features/out_features, so it is not allowed in hf_select_quant_linear scenarios;
# for gptq internal use, allow_marlin is set to True
if format in [FORMAT.GPTQ, FORMAT.GPTQ_V2] and allow_marlin:
allow_backends = [BACKEND.MARLIN] + allow_backends

allow_quant_linears = backend_dict
err = None
global message_logged
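The net effect of `allow_marlin` is an ordering decision: for the GPTQ formats, Marlin is prepended to the candidate list only when the caller permits it, and `hf_select_quant_linear` passes `False` until Marlin's padding is fixed. A condensed sketch of that rule, with illustrative enum values:

from enum import Enum
from typing import List

class BACKEND(Enum):
    MARLIN = "marlin"
    EXLLAMA_V2 = "exllama_v2"
    TRITON = "triton"

def candidate_backends(format_backends: List[BACKEND], is_gptq_format: bool,
                       allow_marlin: bool) -> List[BACKEND]:
    # Marlin cannot pad in/out features, so it is tried first only when
    # the caller explicitly allows it (gptq internal use)
    if is_gptq_format and allow_marlin:
        return [BACKEND.MARLIN] + format_backends
    return format_backends

# the hf_select_quant_linear path keeps Marlin excluded:
print(candidate_backends([BACKEND.EXLLAMA_V2, BACKEND.TRITON], True, False))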
23 changes: 20 additions & 3 deletions gptqmodel/utils/model.py
@@ -129,16 +129,23 @@ def make_quant(
dynamic=dynamic,
)

if pack:
reserve_quant_linear = ExllamaQuantLinear
else:
reserve_quant_linear = ExllamaV2QuantLinear

# TODO: fix this; handle selection of other quant linears
for linear in list(dict.fromkeys([QuantLinear, TorchQuantLinear])):
for linear in list(dict.fromkeys([QuantLinear, reserve_quant_linear, TorchQuantLinear])):
try:
if linear is not QuantLinear:
logger.info(f"Use {QuantLinear} failed, try to use {linear} instead.")

result = create_quant_layer(linear, bits, desc_act, dynamic, group_size, module, names, sym)
return result
except NotImplementedError as e:
# only fall back to other quant linears when the backend is auto.
if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]:
raise e
continue

raise ValueError("no supported quant linear was found for this module.")

@@ -159,9 +166,19 @@ def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module,
elif isinstance(submodule, transformers.pytorch_utils.Conv1D):
in_features = submodule.weight.shape[0]
out_features = submodule.weight.shape[1]
elif isinstance(submodule, BaseQuantLinear):
# if submodule is already a quant layer, we need to get in_features and out_features from the submodule
in_features = submodule.infeatures
out_features = submodule.outfeatures
else:
raise NotImplementedError(f"Unsupported module {submodule}")

# validate in_features and out_features for this quant linear
_, err = QuantLinear.validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym,
infeatures=in_features, outfeatures=out_features)
if err is not None:
raise err

bias = submodule.bias is not None

d_bits = bits
Expand All @@ -183,7 +200,7 @@ def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module,
infeatures=in_features,
outfeatures=out_features,
bias=bias,
weight_dtype=submodule.weight.dtype,
weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype,
)
new_layer.device = ori_layer_device
recurse_setattr(module, name, new_layer.to(ori_layer_device))
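With these pieces in place, `make_quant` first shape-validates the preferred kernel via the extended `validate`, then walks a fallback chain (preferred kernel, then Exllama v1 or v2 depending on `pack`, then Torch), swallowing `NotImplementedError` only when the backend is AUTO. A compact sketch of that loop (function and parameter names are illustrative):

from typing import Callable, Iterable

def pick_quant_linear(candidates: Iterable[type], create: Callable[[type], object],
                      backend_is_auto: bool):
    # dict.fromkeys dedupes the candidate list while preserving order,
    # exactly as the loop in make_quant does
    for linear in dict.fromkeys(candidates):
        try:
            # stands in for create_quant_layer, which now raises the
            # NotImplementedError produced by QuantLinear.validate
            return create(linear)
        except NotImplementedError:
            if not backend_is_auto:
                raise  # an explicitly chosen backend surfaces the error
    raise ValueError("no supported quant linear was found for this module.")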