Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monkey patch HF transformer/optimum/peft support #818

Merged
merged 31 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4c62a76
add peft
CSY-ModelCloud Dec 10, 2024
9d58131
add optimum
CSY-ModelCloud Dec 10, 2024
5a75452
fix import
CSY-ModelCloud Dec 10, 2024
065b14e
update optimum
CSY-ModelCloud Dec 10, 2024
663d4d3
fix optimum import
CSY-ModelCloud Dec 10, 2024
11a9b26
add transformers
CSY-ModelCloud Dec 10, 2024
3e70863
fix transformers import
CSY-ModelCloud Dec 10, 2024
091b594
add patch
CSY-ModelCloud Dec 10, 2024
f972119
add patch
CSY-ModelCloud Dec 10, 2024
395912c
fix patch
CSY-ModelCloud Dec 10, 2024
3514852
fix patch imports
CSY-ModelCloud Dec 10, 2024
fd37bb3
add prefix for imports & move to other dirs
CSY-ModelCloud Dec 10, 2024
46a4223
fix GPTQQuantizer patch
CSY-ModelCloud Dec 11, 2024
257bea2
fix patch error
CSY-ModelCloud Dec 11, 2024
d4005e1
remove unused
CSY-ModelCloud Dec 11, 2024
41fdbca
check if lib is installed
CSY-ModelCloud Dec 11, 2024
7e115e6
replace all for transformers
CSY-ModelCloud Dec 11, 2024
d924bcd
add ExllamaVersion patch
CSY-ModelCloud Dec 11, 2024
39f6ca8
add missing patch
CSY-ModelCloud Dec 11, 2024
3c5e007
patch transformers_testing_utils
CSY-ModelCloud Dec 11, 2024
1d9766a
patch GPTQConfig
CSY-ModelCloud Dec 11, 2024
3a27d04
update init.py
CSY-ModelCloud Dec 11, 2024
9ebf031
delete unused
CSY-ModelCloud Dec 11, 2024
f412372
add another patch
CSY-ModelCloud Dec 11, 2024
d9d9e89
Merge remote-tracking branch 'origin/main' into monkey-patch
CSY-ModelCloud Dec 12, 2024
53f6fc6
fix patch
CSY-ModelCloud Dec 12, 2024
2882e0e
fix ruff
CSY-ModelCloud Dec 12, 2024
dd8e5e0
fix ruff
CSY-ModelCloud Dec 12, 2024
69d57cb
check cuda when there's only cuda device
CSY-ModelCloud Dec 13, 2024
e61b108
exclude triton
CSY-ModelCloud Dec 13, 2024
59cb8d2
set SUPPORTS_TRAINING to exclude
CSY-ModelCloud Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions gptqmodel/integration/integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Optional dependency: optimum. Import both the live modules and our patched
# replacements; if anything fails (package missing, version mismatch, broken
# install), record that optimum support is unavailable and continue.
HAS_OPTIMUM = True
try:
    import optimum.gptq as optimum_gptq
    from optimum.gptq import quantizer as optimum_quantizer
    from optimum.utils import import_utils as optimum_import_utils
    from optimum.utils import testing_utils as optimum_testing_utils

    from .src.optimum.gptq import quantizer as patched_optimum_quantizer
    from .src.optimum.utils import import_utils as patched_optimum_import_utils
    from .src.optimum.utils import testing_utils as patched_optimum_testing_utils
except Exception:
    # Exception, not BaseException: tolerate any import failure, but never
    # swallow KeyboardInterrupt/SystemExit raised while importing.
    HAS_OPTIMUM = False

# Optional dependency: peft. Same graceful-degradation pattern as optimum
# above: any failure importing peft or our patched counterparts disables
# the peft patches instead of crashing at import time.
HAS_PEFT = True
try:
    from peft import import_utils as peft_import_utils
    from peft.tuners.adalora.model import AdaLoraModel as peft_AdaLoraModel
    from peft.tuners.lora import gptq as peft_gptq
    from peft.tuners.lora import model as peft_model
    from peft.utils import other as peft_other

    from .src.peft import import_utils as patched_peft_import_utils
    from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel
    from .src.peft.tuners.lora import gptq as patched_peft_gptq
    from .src.peft.tuners.lora import model as patched_peft_model
    from .src.peft.utils import other as patched_peft_other
except Exception:
    # Exception, not BaseException: tolerate any import failure, but never
    # swallow KeyboardInterrupt/SystemExit raised while importing.
    HAS_PEFT = False

import transformers.testing_utils as transformers_testing_utils # noqa: E402
from transformers.quantizers import quantizer_gptq as transformers_quantizer_gptq # noqa: E402
from transformers.utils import import_utils as transformers_import_utils # noqa: E402
from transformers.utils import quantization_config as transformers_quantization_config # noqa: E402

from .src.transformers import testing_utils as patched_transformers_testing_utils # noqa: E402
from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # noqa: E402
from .src.transformers.utils import import_utils as patched_transformers_import_utils # noqa: E402
from .src.transformers.utils import quantization_config as patched_transformers_quantization_config # noqa: E402


def patch_hf():
    """Monkey-patch peft, optimum and transformers to use GPTQModel.

    Each helper below is a no-op when its target library could not be
    imported, so this is safe to call regardless of what is installed.
    """
    for patcher in (_patch_peft, _patch_optimum, _patch_transformers):
        patcher()


def _patch_peft():
    """Redirect peft's GPTQ integration points to our patched versions.

    No-op when peft (or our patched copy of it) failed to import.
    """
    if not HAS_PEFT:
        return

    # (live module/class, patched counterpart, attribute to swap in)
    replacements = (
        (peft_import_utils, patched_peft_import_utils, "is_gptqmodel_available"),
        (peft_AdaLoraModel, patched_peft_AdaLoraModel, "_create_and_replace"),
        (peft_gptq, patched_peft_gptq, "dispatch_gptq"),
        (peft_model, patched_peft_model, "LoraModel"),
        (peft_other, patched_peft_other, "get_auto_gptq_quant_linear"),
        (peft_other, patched_peft_other, "get_gptqmodel_quant_linear"),
    )
    for target, patched, attr in replacements:
        setattr(target, attr, getattr(patched, attr))


def _patch_optimum():
    """Redirect optimum's GPTQ quantizer machinery to our patched versions.

    No-op when optimum (or our patched copy of it) failed to import.
    """
    if not HAS_OPTIMUM:
        return

    # The top-level gptq package re-exports GPTQQuantizer, so patch it there
    # as well as inside the quantizer module below.
    optimum_gptq.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer

    for attr in ("is_gptqmodel_available", "has_device_more_than_cpu", "ExllamaVersion"):
        setattr(optimum_quantizer, attr, getattr(patched_optimum_quantizer, attr))

    for attr in ("_gptqmodel_available", "is_gptqmodel_available"):
        setattr(optimum_import_utils, attr, getattr(patched_optimum_import_utils, attr))

    optimum_testing_utils.require_gptq = patched_optimum_testing_utils.require_gptq


def _patch_transformers():
    """Patch transformers' GPTQ quantizer, availability checks, config class
    and test decorator to use GPTQModel.

    Unlike the peft/optimum patches, transformers is imported unconditionally
    at the top of this module, so this patch always runs.
    """
    # Local aliases keep the attribute swaps readable.
    quantizer = transformers_quantizer_gptq.GptqHfQuantizer
    patched_quantizer = patched_transformers_quantizer_gptq.GptqHfQuantizer
    quantizer.required_packages = patched_quantizer.required_packages
    quantizer.validate_environment = patched_quantizer.validate_environment
    quantizer._process_model_before_weight_loading = patched_quantizer._process_model_before_weight_loading

    transformers_import_utils._gptqmodel_available = patched_transformers_import_utils._gptqmodel_available
    transformers_import_utils.is_gptqmodel_available = patched_transformers_import_utils.is_gptqmodel_available

    config = transformers_quantization_config.GPTQConfig
    patched_config = patched_transformers_quantization_config.GPTQConfig
    config.__init__ = patched_config.__init__
    config.post_init = patched_config.post_init

    transformers_testing_utils.require_gptq = patched_transformers_testing_utils.require_gptq
Empty file.
Empty file.
Loading
Loading