diff --git a/gptqmodel/models/deepseek_v2.py b/gptqmodel/models/deepseek_v2.py
index d92ffbab..9f9dbb8d 100644
--- a/gptqmodel/models/deepseek_v2.py
+++ b/gptqmodel/models/deepseek_v2.py
@@ -39,4 +39,4 @@ class DeepSeekV2GPTQ(BaseGPTQModel):
         # included in layer 1-59
         ["mlp.shared_experts.gate_proj", "mlp.shared_experts.up_proj"],
         ["mlp.shared_experts.down_proj"],
-    ]
\ No newline at end of file
+    ]
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 9b5cad34..2b6df6a4 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -17,6 +17,8 @@

 from ..models._const import CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
 from ..nn_modules.qlinear import BaseQuantLinear
+from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
+from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as ExllamaV2QuantLinear
 from ..quantization import FORMAT, QuantizeConfig
 from .backend import Backend
 from .importer import select_quant_linear
@@ -393,25 +395,20 @@ def simple_dispatch_model(model, device_map):
     return model


-
-# TODO: refractor. very strange post_init has to re-determine qlinear type again
-# when qliear type is selected, it should auto-override the model post_init method and
-# not have to go about looping over modules to match qlinear type a second time as it is
-# very prone to bugs
 def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[int] = None):
     """
     The max_input_length argument is specific to the exllama backend, that requires to initialize a buffer temp_state.
     """

-    # post init for bitblas backend.
     device_to_buffers_size = {}
-    for _, submodule in model.named_modules():
-        if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "bitblas":
-            submodule.post_init()
-
+    # exllama
     model_uses_exllama = False
+    # exllamav2
+    fixed_bytes = {}
+    model_uses_exllamav2 = False
+
     for name, submodule in model.named_modules():
-        if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllama":
+        if isinstance(submodule, ExllamaQuantLinear):
             model_uses_exllama = True
             device = submodule.qweight.device
             if device not in device_to_buffers_size:
@@ -419,11 +416,7 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[i
                     "max_dq_buffer_size": 1,
                     "max_inner_outer_dim": 1,
                 }
-
-            if not use_act_order:
-                submodule._use_act_order = False
-            else:
-                submodule._use_act_order = True
+            submodule._use_act_order = use_act_order

             # Disable this heuristic for detecting act_order, but it could be used instead of the config.
             """
@@ -447,6 +440,11 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[i
                 submodule.infeatures,
                 submodule.outfeatures,
             )
+        elif isinstance(submodule, ExllamaV2QuantLinear):
+            model_uses_exllamav2 = True
+            device = submodule.qweight.device
+            scratch_fixed = submodule.scratch_space_fixed()
+            fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))

     if model_uses_exllama:
         # To be honest this is quite ugly, not proud of this.
@@ -496,22 +494,6 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[i
         matmul_no_half2 = False
         set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

-        # The buffers need to have been initialized first before calling make_q4.
-        for name, submodule in model.named_modules():
-            if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllama":
-                submodule.post_init()
-
-    # exllamav2
-    fixed_bytes = {}
-    model_uses_exllamav2 = False
-
-    for _, submodule in model.named_modules():
-        if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllamav2":
-            model_uses_exllamav2 = True
-            device = submodule.qweight.device
-            scratch_fixed = submodule.scratch_space_fixed()
-            fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0))
-
     if model_uses_exllamav2:
         from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
@@ -522,10 +504,14 @@ def gptqmodel_post_init(model, use_act_order: bool, max_input_length: Optional[
         # have persistent buffers, otherwise we will get OOM
         model.device_tensors = device_tensors

-        for _, submodule in model.named_modules():
-            if isinstance(submodule, BaseQuantLinear) and submodule.QUANT_TYPE == "exllamav2":
-                device = submodule.qweight.device
-                submodule.post_init(temp_dq=model.device_tensors[device])
+    # The buffers need to have been initialized first before calling make_q4.
+    for _, submodule in model.named_modules():
+        if isinstance(submodule, ExllamaV2QuantLinear):
+            device = submodule.qweight.device
+            submodule.post_init(temp_dq=model.device_tensors[device])
+        elif isinstance(submodule, BaseQuantLinear):
+            submodule.post_init()
+
     torch.cuda.empty_cache()

     return model
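For readers skimming the patch, here is a minimal, self-contained sketch of the single-pass post-init flow this diff moves to: concrete-class `isinstance` checks replace the old string `QUANT_TYPE` comparisons, scratch space is sized per device in a first pass, and every quantized linear is finalized in one final loop. The class bodies below are hypothetical stand-ins (a constant `scratch_space_fixed`, a plain byte tensor instead of `ExLlamaV2DeviceTensors`, a made-up `gptqmodel_post_init_sketch` name); only the names of the real classes and the loop structure mirror the diff.

```python
import torch
import torch.nn as nn


class BaseQuantLinear(nn.Module):
    # Stand-in for gptqmodel's BaseQuantLinear.
    def post_init(self):
        # Generic backends (bitblas, exllama v1 after its buffers exist, ...)
        # finalize themselves here.
        pass


class ExllamaV2QuantLinear(BaseQuantLinear):
    # Stand-in for the exllamav2 QuantLinear imported in the diff.
    def __init__(self, scratch_bytes: int = 1024):
        super().__init__()
        # qweight's device decides which shared scratch buffer this module uses.
        self.register_buffer("qweight", torch.zeros(1, dtype=torch.int32))
        self._scratch_bytes = scratch_bytes

    def scratch_space_fixed(self) -> int:
        return self._scratch_bytes  # the real kernel derives this from layer shapes

    def post_init(self, temp_dq=None):
        self.temp_dq = temp_dq  # attach the per-device scratch tensor


def gptqmodel_post_init_sketch(model: nn.Module):
    # Pass 1: size one scratch buffer per device, keyed by qweight.device,
    # mirroring the fixed_bytes bookkeeping added in the diff.
    fixed_bytes = {}
    for _, submodule in model.named_modules():
        if isinstance(submodule, ExllamaV2QuantLinear):
            device = submodule.qweight.device
            fixed_bytes[device] = max(submodule.scratch_space_fixed(),
                                      fixed_bytes.get(device, 0))

    # One persistent buffer per device; a plain tensor stands in for
    # ExLlamaV2DeviceTensors here.
    device_tensors = {dev: torch.empty(size, dtype=torch.uint8, device=dev)
                      for dev, size in fixed_bytes.items()}

    # Pass 2: a single loop finalizes every quantized linear; exllamav2
    # modules get their shared scratch space, everything else falls through
    # to the generic hook.
    for _, submodule in model.named_modules():
        if isinstance(submodule, ExllamaV2QuantLinear):
            submodule.post_init(temp_dq=device_tensors[submodule.qweight.device])
        elif isinstance(submodule, BaseQuantLinear):
            submodule.post_init()
    return model


model = nn.Sequential(ExllamaV2QuantLinear(), BaseQuantLinear())
gptqmodel_post_init_sketch(model)
```

Dispatching on concrete classes also makes the old `hasattr(submodule, "QUANT_TYPE")` guard unnecessary, and the `elif` ordering guarantees the exllamav2 branch wins before the generic `BaseQuantLinear` fallback.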