From 249bfcc9a8ef74d6dae3cf5aab11e918270ba344 Mon Sep 17 00:00:00 2001
From: Qubitium-modelcloud
Date: Sun, 16 Jun 2024 13:55:31 +0800
Subject: [PATCH] V3 remove triton v1 (#15)

* pkg depends update
* remove fused attention/mlp
* Remove triton v1 and cleanup unused fused files
---
 auto_gptq/modeling/_base.py | 30 +-
 auto_gptq/modeling/_utils.py | 9 +-
 auto_gptq/modeling/auto.py | 2 -
 auto_gptq/nn_modules/_fused_base.py | 36 --
 auto_gptq/nn_modules/fused_gptj_attn.py | 314 ---------------
 auto_gptq/nn_modules/fused_llama_attn.py | 234 -----------
 auto_gptq/nn_modules/fused_llama_mlp.py | 365 ------------------
 .../nn_modules/qlinear/qlinear_triton.py | 217 -----------
 auto_gptq/utils/import_utils.py | 10 +-
 auto_gptq/utils/peft_utils.py | 2 +-
 tests/test_triton.py | 20 +-
 11 files changed, 14 insertions(+), 1225 deletions(-)
 delete mode 100644 auto_gptq/nn_modules/_fused_base.py
 delete mode 100644 auto_gptq/nn_modules/fused_gptj_attn.py
 delete mode 100644 auto_gptq/nn_modules/fused_llama_attn.py
 delete mode 100644 auto_gptq/nn_modules/fused_llama_mlp.py
 delete mode 100644 auto_gptq/nn_modules/qlinear/qlinear_triton.py

diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py index bb6d48923d055e..4a3d1d255367cd 100644 --- a/auto_gptq/modeling/_base.py +++ b/auto_gptq/modeling/_base.py @@ -790,7 +790,6 @@ def from_quantized( trainable: bool = False, disable_exllama: Optional[bool] = None, disable_exllamav2: bool = False, - use_tritonv2: bool = False, checkpoint_format: Optional[str] = None, **kwargs, ): @@ -828,15 +827,9 @@ def from_quantized( if use_qigen and not QIGEN_AVAILABLE: logger.warning("Qigen is not installed, reset use_qigen to False.") use_qigen = False - if use_triton and use_tritonv2: - logging.warn( - "Both use_triton and use_tritonv2 are set to True. Defaulting to use_triton" - ) - use_tritonv2 = False - if (use_triton or use_tritonv2) and not TRITON_AVAILABLE: + if use_triton and not TRITON_AVAILABLE: logger.warning("Triton is not installed, reset use_triton to False.") use_triton = False - use_tritonv2 = False if not disable_exllama and not EXLLAMA_KERNELS_AVAILABLE: logger.warning( "Exllama kernel is not installed, reset disable_exllama to True. " @@ -942,7 +935,7 @@ def from_quantized( disable_exllama = True disable_exllamav2 = True - elif not (use_triton or use_tritonv2) and trainable: + elif not use_triton and trainable: logger.warning( "QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend." ) @@ -1004,7 +997,6 @@ def skip(*args, **kwargs): use_cuda_fp16=use_cuda_fp16, desc_act=quantize_config.desc_act, trainable=trainable, - use_tritonv2=use_tritonv2, ) model.tie_weights() @@ -1051,7 +1043,6 @@ def skip(*args, **kwargs): bits=quantize_config.bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, - use_tritonv2=use_tritonv2, ) # TODO: move this logic in an awq_utils.py file. @@ -1172,7 +1163,6 @@ def skip(*args, **kwargs): disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=False, - use_tritonv2=use_tritonv2, # Get the "original" QuantLinear class ) # Prepare model for marlin load.
@@ -1234,7 +1224,6 @@ def skip(*args, **kwargs): desc_act=quantize_config.desc_act, trainable=trainable, use_qigen=True, - use_tritonv2=use_tritonv2, use_marlin=quantize_config.checkpoint_format == CHECKPOINT_FORMAT.MARLIN, ) preprocess_checkpoint_qigen( @@ -1248,7 +1237,6 @@ def skip(*args, **kwargs): qlinear_kernel = dynamically_import_QuantLinear( use_triton=use_triton, - use_tritonv2=use_tritonv2, desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits, @@ -1293,12 +1281,8 @@ def skip(*args, **kwargs): model.eval() # == step6: (optional) warmup triton == # - if (use_triton or use_tritonv2) and warmup_triton: - if use_tritonv2: - from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear - else: - from ..nn_modules.qlinear.qlinear_triton import QuantLinear - + if use_triton and warmup_triton: + from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear QuantLinear.warmup(model, seqlen=model.seqlen) @@ -1319,7 +1303,7 @@ def skip(*args, **kwargs): model, True, quantize_config, - is_triton_backend=use_triton or use_tritonv2, + is_triton_backend=use_triton, trainable=trainable, qlinear_kernel=qlinear_kernel, ) @@ -1331,8 +1315,7 @@ def warmup_triton(self, enabled: bool = True): logger.warning("triton is not available, skip warmup stage directly.") return - from ..nn_modules.qlinear.qlinear_triton import QuantLinear - + from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear QuantLinear.warmup(self.model, seqlen=self.model.seqlen) def enable_trainable_mode(self, enabled: bool = True): @@ -1356,7 +1339,6 @@ def make_sure_compatible_with_peft( disable_exllamav2: bool = False, use_marlin: bool = False, use_qigen: bool = False, - use_tritonv2: bool = False, ): GeneralQuantLinear.inject_to_model( model, diff --git a/auto_gptq/modeling/_utils.py b/auto_gptq/modeling/_utils.py index 2708ed8ff257e3..f7030a410aff37 100644 --- a/auto_gptq/modeling/_utils.py +++ b/auto_gptq/modeling/_utils.py @@ -81,7 +81,6 @@ def make_quant( use_cuda_fp16: bool = True, desc_act: bool = False, trainable: bool = False, - use_tritonv2: bool = False, ): # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones. 
if disable_exllama is None: @@ -99,7 +98,6 @@ def make_quant( disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen, - use_tritonv2=use_tritonv2, ) if isinstance(module, QuantLinear): @@ -123,7 +121,6 @@ def make_quant( (not (desc_act) or group_size == -1) and not use_triton and not use_qigen - and not use_tritonv2 ): new_layer = QuantLinear( bits, @@ -331,7 +328,6 @@ def pack_model( warmup_triton: bool = False, force_layer_back_to_cpu: bool = False, use_marlin: bool = False, - use_tritonv2: bool = False, ): QuantLinear = dynamically_import_QuantLinear( use_triton=use_triton, @@ -341,7 +337,6 @@ def pack_model( disable_exllama=False, disable_exllamav2=True, use_marlin=use_marlin, - use_tritonv2=use_tritonv2, ) if force_layer_back_to_cpu: @@ -577,9 +572,9 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in def make_sure_no_tensor_in_meta_device( - model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False, use_tritonv2: bool = False, + model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False, ): - QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin, use_tritonv2=use_tritonv2) + QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin) for n, m in model.named_modules(): if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"): m.register_buffer("bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu")) diff --git a/auto_gptq/modeling/auto.py b/auto_gptq/modeling/auto.py index 56bf2f091b8638..e45d63cc502a25 100644 --- a/auto_gptq/modeling/auto.py +++ b/auto_gptq/modeling/auto.py @@ -110,7 +110,6 @@ def from_quantized( disable_exllama: Optional[bool] = None, disable_exllamav2: bool = False, use_marlin: bool = False, - use_tritonv2: bool = False, **kwargs, ) -> BaseGPTQForCausalLM: # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones. 
@@ -158,7 +157,6 @@ def from_quantized( disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin, - use_tritonv2=use_tritonv2, **keywords, ) diff --git a/auto_gptq/nn_modules/_fused_base.py b/auto_gptq/nn_modules/_fused_base.py deleted file mode 100644 index 7d899625cc8a36..00000000000000 --- a/auto_gptq/nn_modules/_fused_base.py +++ /dev/null @@ -1,36 +0,0 @@ -from abc import abstractmethod -from logging import getLogger - -import torch.nn as nn - -from .triton_utils.mixin import TritonModuleMixin - - -logger = getLogger(__name__) - - -class FusedBaseModule(nn.Module, TritonModuleMixin): - @classmethod - @abstractmethod - def inject_to_model(cls, *args, **kwargs): - raise NotImplementedError() - - -class FusedBaseAttentionModule(FusedBaseModule): - @classmethod - @abstractmethod - def inject_to_model( - cls, model, use_triton=False, group_size=-1, use_cuda_fp16=True, desc_act=False, trainable=False, **kwargs - ): - raise NotImplementedError() - - @classmethod - def warmup(cls, model, transpose=False, seqlen=2048): - pass - - -class FusedBaseMLPModule(FusedBaseModule): - @classmethod - @abstractmethod - def inject_to_model(cls, model, use_triton=False, **kwargs): - raise NotImplementedError() diff --git a/auto_gptq/nn_modules/fused_gptj_attn.py b/auto_gptq/nn_modules/fused_gptj_attn.py deleted file mode 100644 index 85f6c349f84feb..00000000000000 --- a/auto_gptq/nn_modules/fused_gptj_attn.py +++ /dev/null @@ -1,314 +0,0 @@ -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.nn import functional as F -from transformers.models.gptj.modeling_gptj import GPTJAttention - -from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear -from ._fused_base import FusedBaseAttentionModule - - -def fixed_pos_embedding(x, seq_dim=1, seq_len=None): - dim = x.shape[-1] - if seq_len is None: - seq_len = x.shape[seq_dim] - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) - sinusoid_inp = ( - torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float() - ) - return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) - - -def rotate_every_two(x): - x1 = x[:, :, :, ::2] - x2 = x[:, :, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') - - -def duplicate_interleave(m): - """ - A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. 
- """ - dim0 = m.shape[0] - m = m.view(-1, 1) # flatten the matrix - m = m.repeat(1, 2) # repeat all elements into the 2nd dimension - m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy - return m - - -def apply_rotary_pos_emb(x, sincos, offset=0): - sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos) - # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) - return (x * cos) + (rotate_every_two(x) * sin) - - -class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule): - def __init__(self, config): - super().__init__() - - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - 1, 1, max_positions, max_positions - ), - ) - self.register_buffer("masked_bias", torch.tensor(-1e9)) - - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.attn_dropout_p = config.attn_pdrop - self.resid_dropout = nn.Dropout(config.resid_pdrop) - - self.embed_dim = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_attention_heads - if self.head_dim * self.num_attention_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and" - f" `num_attention_heads`: {self.num_attention_heads})." - ) - self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) - - self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) - self.rotary_dim = config.rotary_dim - - def _split_heads(self, qkv): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - new_shape = qkv.size()[:-1] + (3, self.num_attention_heads, self.head_dim) - qkv = qkv.view(new_shape) # (batch, seq_length, 3, head, head_features) - query = qkv[:, :, 0] - key = qkv[:, :, 1] - value = qkv[:, :, 2] - - return query, key, value - - def _merge_heads(self, tensor, num_attention_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - if len(tensor.shape) == 5: - tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() - elif len(tensor.shape) == 4: - tensor = tensor.permute(0, 2, 1, 3).contiguous() - else: - raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) - return tensor.view(new_shape) - - def _attn( - self, - query, - key, - value, - attention_mask=None, - head_mask=None, - ): - # compute causal mask from causal mask buffer - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - - # Keep the attention weights computation in fp32 to avoid overflow issues - query = query.to(torch.float32) - key = key.to(torch.float32) - - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
- # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - attn_weights = attn_weights / self.scale_attn - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - attn_weights = attn_weights.to(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - - def forward( - self, - hidden_states: torch.FloatTensor, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Union[ - Tuple[torch.Tensor, Tuple[torch.Tensor]], - Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], - ]: - query, key, value = self._split_heads(self.qkv_proj(hidden_states)) - - seq_len = key.shape[1] - offset = 0 - - if layer_past is not None: - offset = layer_past[0].shape[-2] - seq_len += offset - - if self.rotary_dim is not None: - k_rot = key[:, :, :, : self.rotary_dim] - k_pass = key[:, :, :, self.rotary_dim :] - - q_rot = query[:, :, :, : self.rotary_dim] - q_pass = query[:, :, :, self.rotary_dim :] - - sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) - k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) - q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) - - key = torch.cat([k_rot, k_pass], dim=-1) - query = torch.cat([q_rot, q_pass], dim=-1) - else: - sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) - key = apply_rotary_pos_emb(key, sincos, offset=offset) - query = apply_rotary_pos_emb(query, sincos, offset=offset) - - key = key.permute(0, 2, 1, 3) - query = query.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - - is_causal = layer_past is None - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - if use_cache is True: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - present = (key, value) - else: - present = None - - # compute self-attention: V x Softmax(QK^T) - if compare_pytorch_version("v2.0.0", op="ge"): - attn_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=None if is_causal else attention_mask, - dropout_p=self.attn_dropout_p, - is_causal=is_causal, - ) - attn_weights = None - else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) - attn_output = self.out_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - @classmethod - def inject_to_model( - cls, - model, - use_triton=False, - group_size=-1, - use_cuda_fp16=True, - desc_act=False, - trainable=False, - bits: int = 4, - disable_exllama=True, - disable_exllamav2=False, - **kwargs, - ): - 
config = model.config - QuantLinear = dynamically_import_QuantLinear( - use_triton=use_triton, - desc_act=desc_act, - group_size=group_size, - bits=bits, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - ) - - for name, m in model.named_modules(): - if not isinstance(m, GPTJAttention): - continue - - attn = cls(config).to(device=next(m.buffers()).device) - - q_proj = m.q_proj - k_proj = m.k_proj - v_proj = m.v_proj - - qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) - - if QuantLinear.QUANT_TYPE == "exllama": - if desc_act: - # See fused_llama_attn.py comment - raise ValueError( - "Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True." - ) - else: - g_idx = None - else: - g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) - - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - - qlinear_args = ( - q_proj.bits, - q_proj.group_size, - q_proj.infeatures, - q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, - True if q_proj.bias is not None else False, - ) - qlinear_kwargs = {"trainable": trainable} - if (not desc_act or group_size == -1) and not use_triton: - qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16 - qlinear_kwargs["weight_dtype"] = q_proj.scales.dtype - - qkv_proj = QuantLinear(*qlinear_args, **qlinear_kwargs) - qkv_proj.qweight = qweights - qkv_proj.qzeros = qzeros - qkv_proj.scales = scales - qkv_proj.g_idx = g_idx - qkv_proj.bias = bias - - if "." in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - - attn.qkv_proj = qkv_proj - attn.out_proj = m.out_proj - - setattr(parent, child_name, attn) - del m - - -__all__ = ["FusedGPTJAttentionForQuantizedModel"] diff --git a/auto_gptq/nn_modules/fused_llama_attn.py b/auto_gptq/nn_modules/fused_llama_attn.py deleted file mode 100644 index a2cde1b65f4845..00000000000000 --- a/auto_gptq/nn_modules/fused_llama_attn.py +++ /dev/null @@ -1,234 +0,0 @@ -import math - -import torch -import torch.nn as nn -from torch.nn import functional as F -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - apply_rotary_pos_emb, -) - -from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear -from ._fused_base import FusedBaseAttentionModule - - -class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size, - num_heads, - qkv_proj, - o_proj, - rotary_emb, - layer_idx, - ): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.layer_idx = layer_idx - - if self.head_dim * num_heads != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." 
- ) - self.qkv_proj = qkv_proj - self.o_proj = o_proj - self.rotary_emb = rotary_emb - - def _shape(self, tensor, seq_len, bsz): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states, - past_key_value=None, - attention_mask=None, - position_ids=None, - output_attentions=False, - use_cache=False, - **kwargs, - ): - """Input shape: Batch x Time x Channel""" - - bsz, q_len, _ = hidden_states.size() - - qkv_states = self.qkv_proj(hidden_states) - query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index. Please open an issue in AutoGPTQ if you hit this." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if use_cache: - # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor - # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. 
- query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - if compare_pytorch_version("v2.0.0", op="ge"): - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=attention_mask is None and q_len > 1, - ) - attn_weights = None - else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - @classmethod - def inject_to_model( - cls, - model, - use_triton=False, - group_size=-1, - use_cuda_fp16=True, - desc_act=False, - trainable=False, - bits: int = 4, - disable_exllama=True, - disable_exllamav2=False, - **kwargs, - ): - """ - Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. - """ - QuantLinear = dynamically_import_QuantLinear( - use_triton=use_triton, - desc_act=desc_act, - group_size=group_size, - bits=bits, - disable_exllama=disable_exllama, - disable_exllamav2=disable_exllamav2, - ) - - for name, m in model.named_modules(): - if not isinstance(m, LlamaAttention): - continue - - q_proj = m.q_proj - k_proj = m.k_proj - v_proj = m.v_proj - - qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) - - if QuantLinear.QUANT_TYPE == "exllama": - if desc_act: - # TODO: support it. The issue lies maybe in the line: - # int groups = qzeros.size(0); - # in exllama_ext.cpp - raise ValueError( - "Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True." 
- ) - else: - g_idx = None - else: - g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) - - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - - qlinear_args = ( - q_proj.bits, - q_proj.group_size, - q_proj.infeatures, - q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, - True if q_proj.bias is not None else False, - ) - qlinear_kwargs = {"trainable": trainable} - if (not desc_act or group_size == -1) and not use_triton: - qlinear_kwargs["use_cuda_fp16"] = use_cuda_fp16 - qlinear_kwargs["weight_dtype"] = q_proj.scales.dtype - - qkv_layer = QuantLinear(*qlinear_args, **qlinear_kwargs) - qkv_layer.qweight = qweights - qkv_layer.qzeros = qzeros - qkv_layer.scales = scales - qkv_layer.g_idx = g_idx - qkv_layer.bias = bias - - # Introduced in Transformers 4.36 - layer_idx = None - if hasattr(m, "layer_idx"): - layer_idx = m.layer_idx - attn = cls( - m.hidden_size, - m.num_heads, - qkv_layer, - m.o_proj, - m.rotary_emb, - layer_idx=layer_idx, - ) - - if "." in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - - setattr(parent, child_name, attn) - - -__all__ = ["FusedLlamaAttentionForQuantizedModel"] diff --git a/auto_gptq/nn_modules/fused_llama_mlp.py b/auto_gptq/nn_modules/fused_llama_mlp.py deleted file mode 100644 index f20df0d0226678..00000000000000 --- a/auto_gptq/nn_modules/fused_llama_mlp.py +++ /dev/null @@ -1,365 +0,0 @@ -import math -from logging import getLogger - -import torch -from transformers.models.llama.modeling_llama import LlamaMLP - -from ..utils.import_utils import TRITON_AVAILABLE -from ._fused_base import FusedBaseMLPModule - - -logger = getLogger(__name__) - -if TRITON_AVAILABLE: - import triton - import triton.language as tl - - from .triton_utils import custom_autotune - from .triton_utils.kernels import silu - - @custom_autotune.autotune( - configs=[ - triton.Config( - { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), # 3090 - triton.Config( - { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), # 3090 - triton.Config( - { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - }, - num_stages=2, - num_warps=4, - ), # 3090 - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - num_warps=4, - ), # 3090 - triton.Config( - { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 8, - }, - num_stages=4, - 
num_warps=4, - ), # 3090 - ], - key=["M", "N", "K"], - nearest_power_of_two=True, - prune_configs_by={ - "early_config_prune": custom_autotune.matmul248_kernel_config_pruner, - "perf_model": None, - "top_k": None, - }, - ) - @triton.jit - def quant_fused_matmul_248_kernel( - a_ptr, - c_ptr, - b1_ptr, - scales1_ptr, - zeros1_ptr, - g1_ptr, - b2_ptr, - scales2_ptr, - zeros2_ptr, - g2_ptr, - M, - N, - K, - bits, - maxq, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_scales, - stride_zeros, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ): - """ - Computes: C = silu(A * B1) * (A * B2) - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (1, N) float16 - zeros is of shape (1, N//8) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = offs_am[:, None] < M - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) - b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) - g1_ptrs = g1_ptr + offs_k - g2_ptrs = g2_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales1_ptrs = scales1_ptr + offs_bn[None, :] - scales2_ptrs = scales2_ptr + offs_bn[None, :] - zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits) - zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - g1_idx = tl.load(g1_ptrs) - g2_idx = tl.load(g2_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales) - - zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq - zeros1 = zeros1 + 1 - - zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq - zeros2 = zeros2 + 1 - - a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - b2 = tl.load(b2_ptrs) - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b1 = (b1 >> shifter[:, None]) & maxq # Extract the 
N-bit values - b1 = (b1 - zeros1) * scales1 # Scale and shift - accumulator1 += tl.dot(a, b1) - - b2 = (b2 >> shifter[:, None]) & maxq - b2 = (b2 - zeros2) * scales2 - accumulator2 += tl.dot(a, b2) - - a_ptrs += BLOCK_SIZE_K - b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g1_ptrs += BLOCK_SIZE_K - g2_ptrs += BLOCK_SIZE_K - - accumulator1 = silu(accumulator1) - c = accumulator1 * accumulator2 - c = c.to(tl.float16) - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - -else: - quant_fused_matmul_248_kernel = None - - -class FusedLlamaMLPForQuantizedModel(FusedBaseMLPModule): - def __init__( - self, - gate_proj, - down_proj, - up_proj, - ): - super().__init__() - - self.infeatures = gate_proj.infeatures - self.intermediate_size = gate_proj.outfeatures - self.outfeatures = down_proj.outfeatures - self.bits = gate_proj.bits - self.maxq = gate_proj.maxq - - self.gate_proj = gate_proj - self.up_proj = up_proj - self.down_proj = down_proj - - def forward(self, x): - return self.down_proj(self.triton_llama_mlp(x)) - - def triton_llama_mlp(self, x): - with torch.cuda.device(x.device): - out_shape = x.shape[:-1] + (self.intermediate_size,) - x = x.reshape(-1, x.shape[-1]) - M, K = x.shape - N = self.intermediate_size - c = torch.empty((M, N), device=x.device, dtype=torch.float16) - grid = lambda META: ( # noqa: E731 - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) - quant_fused_matmul_248_kernel[grid]( - x, - c, - self.gate_proj.qweight, - self.gate_proj.scales, - self.gate_proj.qzeros, - self.gate_proj.g_idx, - self.up_proj.qweight, - self.up_proj.scales, - self.up_proj.qzeros, - self.up_proj.g_idx, - M, - N, - K, - self.bits, - self.maxq, - x.stride(0), - x.stride(1), - self.gate_proj.qweight.stride(0), - self.gate_proj.qweight.stride(1), - c.stride(0), - c.stride(1), - self.gate_proj.scales.stride(0), - self.gate_proj.qzeros.stride(0), - ) - c = c.reshape(out_shape) - return c - - @classmethod - def inject_to_model(cls, model, use_triton=False, **kwargs): - if not use_triton: - logger.warning( - f"Skipping module injection for {cls.__name__} as currently not supported with use_triton=False." - ) - return - elif not TRITON_AVAILABLE: - logger.warning( - f"Skipping module injection for {cls.__name__} as Triton is not available. Please check your installation." - ) - return - - for name, m in model.named_modules(): - if not isinstance(m, LlamaMLP): - continue - - mlp = cls(m.gate_proj, m.down_proj, m.up_proj) - - if "." 
in name: - parent_name = name.rsplit(".", 1)[0] - child_name = name[len(parent_name) + 1 :] - parent = model.get_submodule(parent_name) - else: - parent_name = "" - parent = model - child_name = name - - setattr(parent, child_name, mlp) - - @classmethod - def warmup(cls, model, transpose=False, seqlen=2048): - from tqdm import tqdm - - kn_values = {} - - for _, m in model.named_modules(): - if not isinstance(m, cls): - continue - - k = m.infeatures - n = m.intermediate_size - - if (k, n) not in kn_values: - kn_values[(k, n)] = m - - logger.info(f"Found {len(kn_values)} unique fused mlp KN values.") - logger.info("Warming up autotune cache ...") - with torch.no_grad(): - for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)): - m = 2**m - for (k, n), (modules) in kn_values.items(): - a = torch.randn(m, k, dtype=torch.float16, device=model.device) - modules.triton_llama_mlp(a) - del kn_values - - -__all__ = ["FusedLlamaMLPForQuantizedModel"] diff --git a/auto_gptq/nn_modules/qlinear/qlinear_triton.py b/auto_gptq/nn_modules/qlinear/qlinear_triton.py deleted file mode 100644 index c65075b24276f3..00000000000000 --- a/auto_gptq/nn_modules/qlinear/qlinear_triton.py +++ /dev/null @@ -1,217 +0,0 @@ -import math -from logging import getLogger - -import numpy as np -import torch -import torch.nn as nn -import transformers - -from ..triton_utils.mixin import TritonModuleMixin - - -logger = getLogger(__name__) - -try: - from ..triton_utils.kernels import ( - QuantLinearFunction, - QuantLinearInferenceOnlyFunction, - quant_matmul_248, - quant_matmul_inference_only_248, - transpose_quant_matmul_248, - ) -except ImportError as e: - triton_import_exception = e - - def error_raiser_triton(*args, **kwargs): - raise ValueError( - f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}" - ) - - class FakeTriton: - def __getattr__(self, name): - raise ImportError( - f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}" - ) - - quant_matmul_248 = error_raiser_triton - transpose_quant_matmul_248 = error_raiser_triton - quant_matmul_inference_only_248 = error_raiser_triton - QuantLinearFunction = FakeTriton - QuantLinearInferenceOnlyFunction = FakeTriton - - -class QuantLinear(nn.Module, TritonModuleMixin): - QUANT_TYPE = "triton" - - def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs): - super().__init__() - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - if infeatures % 32 != 0 or outfeatures % 32 != 0: - raise NotImplementedError("in_feature and out_feature must be divisible by 32.") - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - self.group_size = group_size if group_size != -1 else infeatures - self.maxq = 2**self.bits - 1 - - self.register_buffer( - "qweight", - torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(infeatures / self.group_size), - outfeatures // 32 * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), - dtype=torch.float16, - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), - ) - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), 
dtype=torch.float16)) - else: - self.bias = None - - self.trainable = trainable - - def post_init(self): - pass - - def pack(self, linear, scales, zeros, g_idx=None): - W = linear.weight.data.clone() - if isinstance(linear, nn.Conv2d): - W = W.flatten(1) - if isinstance(linear, transformers.pytorch_utils.Conv1D): - W = W.t() - - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[ - :, None - ] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - - i = 0 - row = 0 - qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures,) - quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction - out = quant_linear_fn.apply( - x.reshape(-1, x.shape[-1]), - self.qweight, - self.scales, - self.qzeros, - self.g_idx, - self.bits, - self.maxq, - ) - out = out.half().reshape(out_shape) - out = out + self.bias if self.bias is not None else out - return out - - @classmethod - def warmup(cls, model, transpose=False, seqlen=2048): - """ - Pre-tunes the quantized kernel - """ - from tqdm import tqdm - - kn_values = {} - - for _, m in model.named_modules(): - if not isinstance(m, cls): - continue - - k = m.infeatures - n = m.outfeatures - - if (k, n) not in kn_values: - kn_values[(k, n)] = ( - m.qweight, - m.scales, - m.qzeros, - m.g_idx, - m.bits, - m.maxq, - ) - - logger.info(f"Found {len(kn_values)} unique KN Linear values.") - logger.info("Warming up autotune cache ...") - with torch.no_grad(): - for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)): - m = 2**m - for (k, n), ( - qweight, - scales, - qzeros, - g_idx, - bits, - maxq, - ) in kn_values.items(): - if transpose: - a = torch.randn(m, k, dtype=torch.float16, device=model.device) - quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) - a = torch.randn(m, n, dtype=torch.float16, device=model.device) - transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq) - else: - a = torch.randn(m, k, dtype=torch.float16, device=model.device) - quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq) - del kn_values - - -__all__ = ["QuantLinear"] diff --git 
a/auto_gptq/utils/import_utils.py b/auto_gptq/utils/import_utils.py index fe9aabf7db9aae..0e190d10bdc3cf 100644 --- a/auto_gptq/utils/import_utils.py +++ b/auto_gptq/utils/import_utils.py @@ -74,16 +74,14 @@ def dynamically_import_QuantLinear( ) from ..nn_modules.qlinear.qlinear_qigen import QuantLinear else: - if use_triton or use_tritonv2: + if use_triton: if torch.version.hip: logger.warning( "Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False." ) - if use_tritonv2: - logger.debug("Using tritonv2 for GPTQ") - from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear - else: - from ..nn_modules.qlinear.qlinear_triton import QuantLinear + + logger.debug("Using tritonv2 for GPTQ") + from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear else: # If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones. if disable_exllama is None: diff --git a/auto_gptq/utils/peft_utils.py b/auto_gptq/utils/peft_utils.py index 6030ba7419b211..88b4aa2fb4874c 100644 --- a/auto_gptq/utils/peft_utils.py +++ b/auto_gptq/utils/peft_utils.py @@ -16,7 +16,7 @@ from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as QuantLinearExllama from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as QuantLinearExllamaV2 from ..nn_modules.qlinear.qlinear_qigen import QuantLinear as QuantLinearQigen -from ..nn_modules.qlinear.qlinear_triton import QuantLinear as QuantLinearTriton +from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as QuantLinearTriton LinearLayer = Union[ diff --git a/tests/test_triton.py b/tests/test_triton.py index de2286ad027f83..05b42a0685c9a1 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -51,8 +51,6 @@ def amp_wrapper(*inputs, **kwinputs): def get_model_and_tokenizer( model_id=MODEL_ID, - inject_fused_attention=False, - inject_fused_mlp=False, **model_kwargs, ): tokenizer = AutoTokenizer.from_pretrained( @@ -65,8 +63,6 @@ def get_model_and_tokenizer( model = AutoGPTQForCausalLM.from_quantized( model_id, trainable=True, - inject_fused_attention=inject_fused_attention, - inject_fused_mlp=inject_fused_mlp, disable_exllamav2=True, disable_exllama=True, **model_kwargs, @@ -81,27 +77,13 @@ def test_triton_qlinear(self): ref_model, _ = get_model_and_tokenizer( model_id=MODEL_ID, use_triton=True, - inject_fused_attention=False, - inject_fused_mlp=False, - ) - test_model, _ = get_model_and_tokenizer( - model_id=MODEL_ID, - use_tritonv2=True, - inject_fused_attention=False, - inject_fused_mlp=False, ) + hidden_size = ref_model.model.model.embed_tokens.weight.shape[1] test_data = torch.randn((1, 2048, hidden_size), dtype=torch.float16).cuda() qlinear_ref = ref_model.model.model.layers[0].self_attn.q_proj - qlinear_test = test_model.model.model.layers[0].self_attn.q_proj - test_out = qlinear_test(test_data) ref_out = qlinear_ref(test_data) - self.assertTrue(torch.allclose(test_out, ref_out)) - _, measure_triton = benchmark_forward(qlinear_ref, test_data, desc="Triton", verbose=True) - _, measure_tritonv2 = benchmark_forward(qlinear_test, test_data, desc="Triton-v2", verbose=True) - - self.assertTrue(measure_tritonv2.mean < measure_triton.mean)
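
Usage sketch (illustrative note, not part of the diff): after this patch the Triton backend is selected with `use_triton` alone and always resolves to the tritonv2 kernel in auto_gptq/nn_modules/qlinear/qlinear_tritonv2.py; the `use_tritonv2` flag is removed from every public signature. The snippet below mirrors the retained call sites in tests/test_triton.py. MODEL_ID, the device string, and the group_size/bits values passed to dynamically_import_QuantLinear are placeholders chosen for illustration, not values taken from this repository.

    # Hedged sketch of post-patch usage; MODEL_ID is a placeholder checkpoint.
    import torch

    from auto_gptq import AutoGPTQForCausalLM
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

    MODEL_ID = "path/to/any-gptq-quantized-model"  # placeholder

    # `use_triton=True` is the only Triton switch left; it loads the tritonv2
    # QuantLinear. Exllama kernels are disabled here, as in the updated test.
    model = AutoGPTQForCausalLM.from_quantized(
        MODEL_ID,
        device="cuda:0",          # assumed, as in upstream AutoGPTQ examples
        use_triton=True,
        trainable=True,
        disable_exllama=True,
        disable_exllamav2=True,
    )

    # The same kernel class can be resolved directly; note that the
    # `use_tritonv2` keyword no longer exists after this patch.
    QuantLinear = dynamically_import_QuantLinear(
        use_triton=True,
        desc_act=False,    # illustrative values, not repository defaults
        group_size=128,
        bits=4,
        disable_exllama=True,
        disable_exllamav2=True,
    )

    # Exercise one quantized projection, mirroring tests/test_triton.py
    # (Llama-style module path, as used in that test).
    hidden_size = model.model.model.embed_tokens.weight.shape[1]
    x = torch.randn((1, 2048, hidden_size), dtype=torch.float16, device="cuda:0")
    out = model.model.model.layers[0].self_attn.q_proj(x)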