V3 remove fused (huggingface#14)
* update package dependencies

* remove fused attention/MLP modules
Qubitium authored Jun 16, 2024
1 parent 00ba235 commit 1965b96
Showing 10 changed files with 12 additions and 132 deletions.
51 changes: 0 additions & 51 deletions auto_gptq/modeling/_base.py
@@ -24,7 +24,6 @@
create_repo,
)

from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModule
from ..nn_modules.qlinear import GeneralQuantLinear
from ..quantization import GPTQ, BaseQuantizeConfig
from ..quantization.config import (
@@ -98,17 +97,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
inside_layer_modules: List[List[str]] = None
lm_head_name: str = "lm_head"

fused_attn_module_type: Optional[FusedBaseAttentionModule] = None
fused_mlp_module_type: Optional[FusedBaseMLPModule] = None

def __init__(
self,
model: PreTrainedModel,
quantized: bool,
quantize_config: BaseQuantizeConfig,
is_triton_backend: bool = False,
injected_fused_attention: bool = False,
injected_fused_mlp: bool = False,
trainable: bool = False,
qlinear_kernel: nn.Module = None
):
@@ -121,8 +115,6 @@ def __init__(
self.config = self.model.config

self.is_triton_backend = is_triton_backend
self.injected_fused_attention = injected_fused_attention
self.injected_fused_mlp = injected_fused_mlp
self.trainable = trainable

# compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
@@ -789,8 +781,6 @@ def from_quantized(
use_qigen: bool = False,
use_marlin: bool = False,
torch_dtype: Optional[torch.dtype] = None,
inject_fused_attention: bool = False,
inject_fused_mlp: bool = False,
use_cuda_fp16: bool = True,
quantize_config: Optional[BaseQuantizeConfig] = None,
model_basename: Optional[str] = None,
@@ -876,8 +866,6 @@ def from_quantized(

if use_qigen and QIGEN_AVAILABLE:
logger.warning("QIgen is active. Ignores all settings related to cuda.")
inject_fused_attention = False
inject_fused_mlp = False
use_triton = False
disable_exllama = True
disable_exllamav2 = True
@@ -1200,13 +1188,6 @@ def skip(*args, **kwargs):
device_map=device_map,
)

# Disable incompatible optimizations.
if inject_fused_attention or inject_fused_mlp:
# TODO: Validate whether that can be used.
logger.info("Disabling fused attention and mlp injection because Marlin kernel is used.")
inject_fused_attention = False
inject_fused_mlp = False

accelerate.utils.modeling.load_checkpoint_in_model(
model,
dtype=torch_dtype, # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
@@ -1306,31 +1287,6 @@ def skip(*args, **kwargs):
logger.warning("can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096

# == step5: (optional) inject optimized module == #
if inject_fused_attention:
if cls.fused_attn_module_type is None:
inject_fused_attention = False
logger.warning(f"{cls.__name__} hasn't fused attention module yet, will skip inject fused attention.")
else:
cls.fused_attn_module_type.inject_to_model(
model,
use_triton=use_triton,
group_size=quantize_config.group_size,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable,
bits=quantize_config.bits,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_tritonv2=use_tritonv2,
)
if inject_fused_mlp:
if cls.fused_mlp_module_type is None:
inject_fused_mlp = False
logger.warning(f"{cls.__name__} hasn't fused mlp module yet, will skip inject fused mlp.")
else:
cls.fused_mlp_module_type.inject_to_model(model, use_triton=use_triton)

# Any post-initialization that require device information, for example buffers initialization on device.
model = autogptq_post_init(model, use_act_order=quantize_config.desc_act)

@@ -1345,8 +1301,6 @@ def skip(*args, **kwargs):

QuantLinear.warmup(model, seqlen=model.seqlen)

if inject_fused_mlp and cls.fused_mlp_module_type is not None:
cls.fused_mlp_module_type.warmup(model, seqlen=model.seqlen)

# == step7: make model compatible with peft
# cls.make_sure_compatible_with_peft(
@@ -1366,8 +1320,6 @@ def skip(*args, **kwargs):
True,
quantize_config,
is_triton_backend=use_triton or use_tritonv2,
injected_fused_attention=inject_fused_attention,
injected_fused_mlp=inject_fused_mlp and (use_triton or use_tritonv2),
trainable=trainable,
qlinear_kernel=qlinear_kernel,
)
@@ -1383,9 +1335,6 @@ def warmup_triton(self, enabled: bool = True):

QuantLinear.warmup(self.model, seqlen=self.model.seqlen)

if self.fused_mlp_module_type is not None:
self.fused_mlp_module_type.warmup(self.model, seqlen=self.model.seqlen)

def enable_trainable_mode(self, enabled: bool = True):
if not self.is_triton_backend and enabled:
raise NotImplementedError("For now, trainable mode only supports triton backend.")
4 changes: 0 additions & 4 deletions auto_gptq/modeling/auto.py
@@ -100,8 +100,6 @@ def from_quantized(
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
inject_fused_attention: bool = False,
inject_fused_mlp: bool = False,
use_cuda_fp16: bool = True,
quantize_config: Optional[BaseQuantizeConfig] = None,
model_basename: Optional[str] = None,
@@ -150,8 +148,6 @@ def from_quantized(
device=device,
low_cpu_mem_usage=low_cpu_mem_usage,
use_triton=use_triton,
inject_fused_attention=inject_fused_attention,
inject_fused_mlp=inject_fused_mlp,
use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
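
Taken together, the _base.py and auto.py changes above remove the inject_fused_attention and inject_fused_mlp keyword arguments from from_quantized. A minimal usage sketch after this commit follows; the model id is an illustrative placeholder, and only arguments visible in the diff above are used:

from auto_gptq import AutoGPTQForCausalLM

# Load an already-quantized GPTQ checkpoint. The fused-injection keywords
# removed by this commit (inject_fused_attention / inject_fused_mlp) are
# simply omitted; callers that still pass them need to drop those arguments
# when upgrading.
model = AutoGPTQForCausalLM.from_quantized(
    "org/some-model-GPTQ",  # placeholder repository id, not from this diff
    device="cuda:0",
    use_triton=False,
)
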
10 changes: 0 additions & 10 deletions auto_gptq/modeling/decilm.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class DeciLMGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["DeciLMGPTQForCausalLM"]
3 changes: 0 additions & 3 deletions auto_gptq/modeling/gptj.py
@@ -1,4 +1,3 @@
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
from ._base import BaseGPTQForCausalLM


@@ -13,7 +12,5 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_out"],
]

fused_attn_module_type = FusedGPTJAttentionForQuantizedModel


__all__ = ["GPTJGPTQForCausalLM"]
11 changes: 0 additions & 11 deletions auto_gptq/modeling/llama.py
@@ -3,14 +3,6 @@
from ..utils.import_utils import compare_transformers_version
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +17,5 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["LlamaGPTQForCausalLM"]
10 changes: 0 additions & 10 deletions auto_gptq/modeling/longllama.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class LongLlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["LongLlamaGPTQForCausalLM"]
10 changes: 0 additions & 10 deletions auto_gptq/modeling/stablelmepoch.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class StableLMEpochGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["StableLMEpochGPTQForCausalLM"]
10 changes: 0 additions & 10 deletions auto_gptq/modeling/xverse.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class XverseGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["XverseGPTQForCausalLM"]
10 changes: 0 additions & 10 deletions auto_gptq/modeling/yi.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class YiGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["YiGPTQForCausalLM"]
25 changes: 12 additions & 13 deletions setup.py
@@ -118,19 +118,18 @@ def detect_local_sm_architectures():
common_setup_kwargs['version'] += f"+cu{CUDA_VERSION}"

requirements = [
"accelerate>=0.29.2",
"datasets",
"sentencepiece",
"numpy",
"rouge",
"gekko",
"torch>=1.13.0",
"safetensors>=0.3.1",
"transformers>=4.31.0",
"peft>=0.5.0",
"tqdm",
"threadpoolctl",
"packaging",
"accelerate>=0.31.0",
"datasets>=2.20.0",
"sentencepiece>=0.2.0",
"numpy>=1.26.4",
"rouge>=1.0.1",
"gekko>=1.1.1",
"torch>=2.3.1",
"safetensors>=0.4.3",
"transformers>=4.41.2",
"tqdm>=4.66.4",
"threadpoolctl>=3.5.0",
"packaging>=24.1",
]

extras_require = {