Skip to content

Commit

Permalink
exclude marlin & exllama on windows (#898)
Browse files Browse the repository at this point in the history
* exclude marlin & exllama on windows

* validate as false when qlinear was not installed

* fix list combination

* simplify validate args
  • Loading branch information
CSY-ModelCloud authored Dec 18, 2024
1 parent cb41aca commit 8b445a3
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 44 deletions.
6 changes: 2 additions & 4 deletions gptqmodel/nn_modules/qlinear/bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,10 @@ def __init__(
self.reset_parameters()

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if not BITBLAS_AVAILABLE:
return False, ValueError(BITBLAS_INSTALL_HINT)
return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
return cls._validate(**args)

def _validate_parameters(
self, group_size: int, infeatures: int, outfeatures: int
Expand Down
7 changes: 7 additions & 0 deletions gptqmodel/nn_modules/qlinear/exllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
from logging import getLogger
from typing import Optional, Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -112,6 +113,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None

@classmethod
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
    """Check whether this exllama quantized-linear kernel can be used.

    Args:
        **args: Keyword arguments forwarded unchanged to ``cls._validate``.

    Returns:
        ``(False, exllama_import_exception)`` when the module-level
        ``exllama_import_exception`` is set; otherwise whatever
        ``cls._validate(**args)`` returns.
    """
    # NOTE(review): exllama_import_exception is presumably captured when the
    # compiled exllama extension fails to import - confirm at module top.
    import_error = exllama_import_exception
    if import_error is None:
        return cls._validate(**args)
    return False, import_error

def post_init(self):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None
Expand Down
7 changes: 7 additions & 0 deletions gptqmodel/nn_modules/qlinear/exllamav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2

import math
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -179,6 +180,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None

@classmethod
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
    """Check whether this exllama v2 quantized-linear kernel can be used.

    Args:
        **args: Keyword arguments forwarded unchanged to ``cls._validate``.

    Returns:
        ``(False, exllama_v2_import_exception)`` when the module-level
        ``exllama_v2_import_exception`` is set; otherwise whatever
        ``cls._validate(**args)`` returns.
    """
    # NOTE(review): exllama_v2_import_exception is presumably captured when
    # the compiled exllamav2 extension fails to import - confirm at module top.
    import_error = exllama_v2_import_exception
    if import_error is None:
        return cls._validate(**args)
    return False, import_error

def post_init(self, temp_dq):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None
Expand Down
6 changes: 2 additions & 4 deletions gptqmodel/nn_modules/qlinear/ipex.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,13 @@ def __init__(
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if sys.platform != "linux":
return False, Exception("IPEX is only available on Linux platform.")

if not HAS_IPEX:
return False, IPEX_ERROR_LOG
return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
return cls._validate(**args)

def post_init(self):
self.validate_device(self.qweight.device.type)
Expand Down
6 changes: 6 additions & 0 deletions gptqmodel/nn_modules/qlinear/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None

@classmethod
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
    """Check whether this marlin quantized-linear kernel can be used.

    Args:
        **args: Keyword arguments forwarded unchanged to ``cls._validate``.

    Returns:
        ``(False, marlin_import_exception)`` when the module-level
        ``marlin_import_exception`` is set; otherwise whatever
        ``cls._validate(**args)`` returns.
    """
    # NOTE(review): marlin_import_exception is presumably captured when the
    # compiled marlin extension fails to import - confirm at module top.
    import_error = marlin_import_exception
    if import_error is None:
        return cls._validate(**args)
    return False, import_error

def post_init(self):
device = self.qweight.device
self.validate_device(device.type)
Expand Down
6 changes: 2 additions & 4 deletions gptqmodel/nn_modules/qlinear/tritonv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,10 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
self.bias = None

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if not TRITON_AVAILABLE:
return False, ValueError(TRITON_INSTALL_HINT)
return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
return cls._validate(**args)

def post_init(self):
self.validate_device(self.qweight.device.type)
Expand Down
71 changes: 39 additions & 32 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,40 +174,47 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args,
),
cpp_ext.CUDAExtension(
"gptqmodel_marlin_kernels",
[
"gptqmodel_ext/marlin/marlin_cuda.cpp",
"gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
"gptqmodel_ext/marlin/marlin_repack.cu",
],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args,
),
cpp_ext.CUDAExtension(
"gptqmodel_exllama_kernels",
[
"gptqmodel_ext/exllama/exllama_ext.cpp",
"gptqmodel_ext/exllama/cuda_buffers.cu",
"gptqmodel_ext/exllama/cuda_func/column_remap.cu",
"gptqmodel_ext/exllama/cuda_func/q4_matmul.cu",
"gptqmodel_ext/exllama/cuda_func/q4_matrix.cu",
],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args,
),
cpp_ext.CUDAExtension(
"gptqmodel_exllamav2_kernels",
[
"gptqmodel_ext/exllamav2/ext.cpp",
"gptqmodel_ext/exllamav2/cuda/q_matrix.cu",
"gptqmodel_ext/exllamav2/cuda/q_gemm.cu",
],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args,
)
]

if sys.platform != "win32":
    # These kernels are only registered off Windows; each entry notes the
    # VC++ failure that is still to be resolved (TODO).
    extensions += [
        cpp_ext.CUDAExtension(
            kernel_name,
            kernel_sources,
            extra_link_args=extra_link_args,
            extra_compile_args=extra_compile_args,
        )
        for kernel_name, kernel_sources in (
            # TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
            (
                "gptqmodel_marlin_kernels",
                [
                    "gptqmodel_ext/marlin/marlin_cuda.cpp",
                    "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
                    "gptqmodel_ext/marlin/marlin_repack.cu",
                ],
            ),
            # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
            (
                "gptqmodel_exllama_kernels",
                [
                    "gptqmodel_ext/exllama/exllama_ext.cpp",
                    "gptqmodel_ext/exllama/cuda_buffers.cu",
                    "gptqmodel_ext/exllama/cuda_func/column_remap.cu",
                    "gptqmodel_ext/exllama/cuda_func/q4_matmul.cu",
                    "gptqmodel_ext/exllama/cuda_func/q4_matrix.cu",
                ],
            ),
            # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
            (
                "gptqmodel_exllamav2_kernels",
                [
                    "gptqmodel_ext/exllamav2/ext.cpp",
                    "gptqmodel_ext/exllamav2/cuda/q_matrix.cu",
                    "gptqmodel_ext/exllamav2/cuda/q_gemm.cu",
                ],
            ),
        )
    ]

additional_setup_kwargs = {"ext_modules": extensions, "cmdclass": {"build_ext": cpp_ext.BuildExtension}}


Expand Down

0 comments on commit 8b445a3

Please sign in to comment.