diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py
index 4e5535d5b..f8ea5e5b6 100644
--- a/gptqmodel/nn_modules/qlinear/bitblas.py
+++ b/gptqmodel/nn_modules/qlinear/bitblas.py
@@ -137,12 +137,10 @@ def __init__(
         self.reset_parameters()

     @classmethod
-    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
-                 outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
-        bool, Optional[Exception]]:
+    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
         if not BITBLAS_AVAILABLE:
             return False, ValueError(BITBLAS_INSTALL_HINT)
-        return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
+        return cls._validate(**args)

     def _validate_parameters(
         self, group_size: int, infeatures: int, outfeatures: int
diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py
index e55533503..fd2684de5 100644
--- a/gptqmodel/nn_modules/qlinear/exllama.py
+++ b/gptqmodel/nn_modules/qlinear/exllama.py
@@ -3,6 +3,7 @@

 import math
 from logging import getLogger
+from typing import Optional, Tuple

 import numpy as np
 import torch
@@ -112,6 +113,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
         else:
             self.bias = None

+    @classmethod
+    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+        if exllama_import_exception is not None:
+            return False, exllama_import_exception
+        return cls._validate(**args)
+
     def post_init(self):
         self.validate_device(self.qweight.device.type)
         assert self.qweight.device.index is not None
diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py
index 5d9f752f5..ac9425577 100644
--- a/gptqmodel/nn_modules/qlinear/exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -2,6 +2,7 @@
 # Adapted from turboderp exllama: https://github.com/turboderp/exllamav2

 import math
+from typing import Tuple, Optional

 import torch
 import torch.nn.functional as F
@@ -179,6 +180,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
         else:
             self.bias = None

+    @classmethod
+    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+        if exllama_v2_import_exception is not None:
+            return False, exllama_v2_import_exception
+        return cls._validate(**args)
+
     def post_init(self, temp_dq):
         self.validate_device(self.qweight.device.type)
         assert self.qweight.device.index is not None
diff --git a/gptqmodel/nn_modules/qlinear/ipex.py b/gptqmodel/nn_modules/qlinear/ipex.py
index 009759e8e..16ef7a046 100644
--- a/gptqmodel/nn_modules/qlinear/ipex.py
+++ b/gptqmodel/nn_modules/qlinear/ipex.py
@@ -134,15 +134,13 @@ def __init__(
         self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)

     @classmethod
-    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
-                 outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
-        bool, Optional[Exception]]:
+    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
         if sys.platform != "linux":
             return False, Exception("IPEX is only available on Linux platform.")

         if not HAS_IPEX:
             return False, IPEX_ERROR_LOG
-        return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
+        return cls._validate(**args)

     def post_init(self):
         self.validate_device(self.qweight.device.type)
diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py
index 7866a084f..71765f8ab 100644
--- a/gptqmodel/nn_modules/qlinear/marlin.py
+++ b/gptqmodel/nn_modules/qlinear/marlin.py
@@ -284,6 +284,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
         else:
             self.bias = None

+    @classmethod
+    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+        if marlin_import_exception is not None:
+            return False, marlin_import_exception
+        return cls._validate(**args)
+
     def post_init(self):
         device = self.qweight.device
         self.validate_device(device.type)
diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py
index 00715c1d7..70c99ddbe 100644
--- a/gptqmodel/nn_modules/qlinear/tritonv2.py
+++ b/gptqmodel/nn_modules/qlinear/tritonv2.py
@@ -99,12 +99,10 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
         self.bias = None

     @classmethod
-    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
-                 outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
-        bool, Optional[Exception]]:
+    def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
         if not TRITON_AVAILABLE:
             return False, ValueError(TRITON_INSTALL_HINT)
-        return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
+        return cls._validate(**args)

     def post_init(self):
         self.validate_device(self.qweight.device.type)
diff --git a/setup.py b/setup.py
index 0a8ab3edd..dd54dc13d 100644
--- a/setup.py
+++ b/setup.py
@@ -174,40 +174,47 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
             extra_link_args=extra_link_args,
             extra_compile_args=extra_compile_args,
         ),
-        cpp_ext.CUDAExtension(
-            "gptqmodel_marlin_kernels",
-            [
-                "gptqmodel_ext/marlin/marlin_cuda.cpp",
-                "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
-                "gptqmodel_ext/marlin/marlin_repack.cu",
-            ],
-            extra_link_args=extra_link_args,
-            extra_compile_args=extra_compile_args,
-        ),
-        cpp_ext.CUDAExtension(
-            "gptqmodel_exllama_kernels",
-            [
-                "gptqmodel_ext/exllama/exllama_ext.cpp",
-                "gptqmodel_ext/exllama/cuda_buffers.cu",
-                "gptqmodel_ext/exllama/cuda_func/column_remap.cu",
-                "gptqmodel_ext/exllama/cuda_func/q4_matmul.cu",
-                "gptqmodel_ext/exllama/cuda_func/q4_matrix.cu",
-            ],
-            extra_link_args=extra_link_args,
-            extra_compile_args=extra_compile_args,
-        ),
-        cpp_ext.CUDAExtension(
-            "gptqmodel_exllamav2_kernels",
-            [
-                "gptqmodel_ext/exllamav2/ext.cpp",
-                "gptqmodel_ext/exllamav2/cuda/q_matrix.cu",
-                "gptqmodel_ext/exllamav2/cuda/q_gemm.cu",
-            ],
-            extra_link_args=extra_link_args,
-            extra_compile_args=extra_compile_args,
-        )
     ]
+    if sys.platform != "win32":
+        extensions += [
+            # TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
+            cpp_ext.CUDAExtension(
+                "gptqmodel_marlin_kernels",
+                [
+                    "gptqmodel_ext/marlin/marlin_cuda.cpp",
+                    "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
+                    "gptqmodel_ext/marlin/marlin_repack.cu",
+                ],
+                extra_link_args=extra_link_args,
+                extra_compile_args=extra_compile_args,
+            ),
+            # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
+            cpp_ext.CUDAExtension(
+                "gptqmodel_exllama_kernels",
+                [
+                    "gptqmodel_ext/exllama/exllama_ext.cpp",
+                    "gptqmodel_ext/exllama/cuda_buffers.cu",
+                    "gptqmodel_ext/exllama/cuda_func/column_remap.cu",
+                    "gptqmodel_ext/exllama/cuda_func/q4_matmul.cu",
+                    "gptqmodel_ext/exllama/cuda_func/q4_matrix.cu",
+                ],
+                extra_link_args=extra_link_args,
+                extra_compile_args=extra_compile_args,
+            ),
+            # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
+            cpp_ext.CUDAExtension(
+                "gptqmodel_exllamav2_kernels",
+                [
+                    "gptqmodel_ext/exllamav2/ext.cpp",
+                    "gptqmodel_ext/exllamav2/cuda/q_matrix.cu",
+                    "gptqmodel_ext/exllamav2/cuda/q_gemm.cu",
+                ],
+                extra_link_args=extra_link_args,
+                extra_compile_args=extra_compile_args,
+            )
+        ]
+    additional_setup_kwargs = {"ext_modules": extensions, "cmdclass": {"build_ext": cpp_ext.BuildExtension}}