V3 remove triton v1 (huggingface#15)
* pkg depends update

* remove fused attention/mlp

* Remove triton v1 and cleanup unused fused files
Qubitium authored Jun 16, 2024
1 parent 1965b96 commit 249bfcc
Showing 11 changed files with 14 additions and 1,225 deletions.
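
In practical terms, callers that previously opted into the Triton v2 kernel via use_tritonv2 now pass only use_triton, which after this commit resolves to qlinear_tritonv2 (see the warmup hunks below). A minimal usage sketch; the model path, device, and class name follow upstream AutoGPTQ conventions and are illustrative rather than part of this diff:

    from auto_gptq import AutoGPTQForCausalLM

    # Before this commit a caller could pass use_tritonv2=True (or use_triton=True for
    # the old v1 kernel). After it, use_triton alone selects the Triton v2 QuantLinear.
    model = AutoGPTQForCausalLM.from_quantized(
        "path/to/quantized-model",  # illustrative path
        device="cuda:0",            # illustrative device
        use_triton=True,            # maps to the qlinear_tritonv2 kernel
        warmup_triton=True,         # optional: pre-compile Triton kernels at load time
    )

The same flag removal runs through the rest of the API surface shown below: make_quant, pack_model, dynamically_import_QuantLinear, and make_sure_no_tensor_in_meta_device no longer accept or forward use_tritonv2.
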
30 changes: 6 additions & 24 deletions auto_gptq/modeling/_base.py
@@ -790,7 +790,6 @@ def from_quantized(
trainable: bool = False,
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_tritonv2: bool = False,
checkpoint_format: Optional[str] = None,
**kwargs,
):
@@ -828,15 +827,9 @@ def from_quantized(
if use_qigen and not QIGEN_AVAILABLE:
logger.warning("Qigen is not installed, reset use_qigen to False.")
use_qigen = False
if use_triton and use_tritonv2:
logging.warn(
"Both use_triton and use_tritonv2 are set to True. Defaulting to use_triton"
)
use_tritonv2 = False
if (use_triton or use_tritonv2) and not TRITON_AVAILABLE:
if use_triton and not TRITON_AVAILABLE:
logger.warning("Triton is not installed, reset use_triton to False.")
use_triton = False
use_tritonv2 = False
if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE:
logger.warning(
"Exllama kernel is not installed, reset disable_exllama to True. "
@@ -942,7 +935,7 @@ def from_quantized(
disable_exllama = True
disable_exllamav2 = True

elif not (use_triton or use_tritonv2) and trainable:
elif not use_triton and trainable:
logger.warning(
"QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend."
)
@@ -1004,7 +997,6 @@ def skip(*args, **kwargs):
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable,
use_tritonv2=use_tritonv2,
)
model.tie_weights()

@@ -1051,7 +1043,6 @@ def skip(*args, **kwargs):
bits=quantize_config.bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_tritonv2=use_tritonv2,
)

# TODO: move this logic in an awq_utils.py file.
@@ -1172,7 +1163,6 @@ def skip(*args, **kwargs):
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_marlin=False,
use_tritonv2=use_tritonv2, # Get the "original" QuantLinear class
)

# Prepare model for marlin load.
@@ -1234,7 +1224,6 @@ def skip(*args, **kwargs):
desc_act=quantize_config.desc_act,
trainable=trainable,
use_qigen=True,
use_tritonv2=use_tritonv2,
use_marlin=quantize_config.checkpoint_format == CHECKPOINT_FORMAT.MARLIN,
)
preprocess_checkpoint_qigen(
@@ -1248,7 +1237,6 @@ def skip(*args, **kwargs):

qlinear_kernel = dynamically_import_QuantLinear(
use_triton=use_triton,
use_tritonv2=use_tritonv2,
desc_act=quantize_config.desc_act,
group_size=quantize_config.group_size,
bits=quantize_config.bits,
@@ -1293,12 +1281,8 @@ def skip(*args, **kwargs):
model.eval()

# == step6: (optional) warmup triton == #
if (use_triton or use_tritonv2) and warmup_triton:
if use_tritonv2:
from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
else:
from ..nn_modules.qlinear.qlinear_triton import QuantLinear

if use_triton and warmup_triton:
from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
QuantLinear.warmup(model, seqlen=model.seqlen)


@@ -1319,7 +1303,7 @@ def skip(*args, **kwargs):
model,
True,
quantize_config,
is_triton_backend=use_triton or use_tritonv2,
is_triton_backend=use_triton,
trainable=trainable,
qlinear_kernel=qlinear_kernel,
)
@@ -1331,8 +1315,7 @@ def warmup_triton(self, enabled: bool = True):
logger.warning("triton is not available, skip warmup stage directly.")
return

from ..nn_modules.qlinear.qlinear_triton import QuantLinear

from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
QuantLinear.warmup(self.model, seqlen=self.model.seqlen)

def enable_trainable_mode(self, enabled: bool = True):
@@ -1356,7 +1339,6 @@ def make_sure_compatible_with_peft(
disable_exllamav2: bool = False,
use_marlin: bool = False,
use_qigen: bool = False,
use_tritonv2: bool = False,
):
GeneralQuantLinear.inject_to_model(
model,
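
With the v1 kernel gone, both warmup call sites in _base.py import the same module. For code outside the library that warms the kernels manually, here is a hedged sketch of the equivalent call after this commit; the module path and the warmup classmethod are taken from the hunks above, while the wrapper function name is purely illustrative:

    # Only the v2 kernel module remains; the old
    # auto_gptq.nn_modules.qlinear.qlinear_triton import path was removed.
    from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear

    def warmup_quantized_model(model) -> None:
        # Pre-compiles the Triton v2 kernels for the model's configured sequence length,
        # mirroring what from_quantized(..., use_triton=True, warmup_triton=True) does.
        QuantLinear.warmup(model, seqlen=model.seqlen)
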
9 changes: 2 additions & 7 deletions auto_gptq/modeling/_utils.py
@@ -81,7 +81,6 @@ def make_quant(
use_cuda_fp16: bool = True,
desc_act: bool = False,
trainable: bool = False,
use_tritonv2: bool = False,
):
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
if disable_exllama is None:
@@ -99,7 +98,6 @@ def make_quant(
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen,
use_tritonv2=use_tritonv2,
)

if isinstance(module, QuantLinear):
@@ -123,7 +121,6 @@ def make_quant(
(not (desc_act) or group_size == -1)
and not use_triton
and not use_qigen
and not use_tritonv2
):
new_layer = QuantLinear(
bits,
@@ -331,7 +328,6 @@ def pack_model(
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False,
use_marlin: bool = False,
use_tritonv2: bool = False,
):
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
@@ -341,7 +337,6 @@ def pack_model(
disable_exllama=False,
disable_exllamav2=True,
use_marlin=use_marlin,
use_tritonv2=use_tritonv2,
)

if force_layer_back_to_cpu:
@@ -577,9 +572,9 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[int]


def make_sure_no_tensor_in_meta_device(
model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False, use_tritonv2: bool = False,
model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False,
):
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin, use_tritonv2=use_tritonv2)
QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin)
for n, m in model.named_modules():
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
m.register_buffer("bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
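
Kernel selection in _utils.py follows the same pattern: dynamically_import_QuantLinear is called without use_tritonv2, and use_triton on its own picks the v2 kernel. A sketch of a direct call after this change; the keyword names are the ones visible in the hunks above, the import path is assumed from the upstream AutoGPTQ layout, and the argument values are illustrative:

    # Hypothetical call site. The module auto_gptq.utils.import_utils is assumed
    # (as in upstream AutoGPTQ); it is not part of this diff.
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

    QuantLinear = dynamically_import_QuantLinear(
        use_triton=True,        # Triton now unconditionally means the v2 kernel
        desc_act=False,
        group_size=128,
        bits=4,
        disable_exllama=True,
        disable_exllamav2=True,
        use_marlin=False,       # use_tritonv2 is no longer an accepted keyword
    )
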
2 changes: 0 additions & 2 deletions auto_gptq/modeling/auto.py
@@ -110,7 +110,6 @@ def from_quantized(
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_marlin: bool = False,
use_tritonv2: bool = False,
**kwargs,
) -> BaseGPTQForCausalLM:
# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
@@ -158,7 +157,6 @@ def from_quantized(
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_marlin=use_marlin,
use_tritonv2=use_tritonv2,
**keywords,
)

36 changes: 0 additions & 36 deletions auto_gptq/nn_modules/_fused_base.py

This file was deleted.

(Diffs for the remaining changed files were not loaded in this view.)