Enables CPU AWQ model with IPEX version. #33460

Merged · 13 commits · Oct 4, 2024
40 changes: 40 additions & 0 deletions docs/source/en/quantization/awq.md
@@ -230,3 +230,43 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
Note this feature is supported on AMD GPUs.

</Tip>


## CPU support

Recent versions of `autoawq` support CPU with IPEX op optimizations. To get started, install `intel-extension-for-pytorch` and the latest version of `autoawq`:

```bash
pip install intel-extension-for-pytorch
pip install git+https://github.com/casper-hansen/AutoAWQ.git
```
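
If you want to confirm the IPEX install is visible to Python (an optional sanity check, not part of the original instructions), a minimal snippet like this should work:

```python
# Optional sanity check (a sketch): confirm intel-extension-for-pytorch imports
# and report its version.
import intel_extension_for_pytorch as ipex

print(ipex.__version__)
```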

Then pass an `AwqConfig` with `version="ipex"` when loading the model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

quantization_config = AwqConfig(version="ipex")

model = AutoModelForCausalLM.from_pretrained(
"TheBloke/WizardLM-1.0-Uncensored-Llama2-13B-AWQ",
quantization_config=quantization_config,
device_map="cpu",
)

input_ids = torch.randint(0, 100, (1, 128), dtype=torch.long, device="cpu")
output = model(input_ids)
print(output.logits)

tokenizer = AutoTokenizer.from_pretrained("TheBloke/WizardLM-1.0-Uncensored-Llama2-13B-AWQ")
input_ids = tokenizer.encode("How to make a cake", return_tensors="pt").to(model.device)
output = model.generate(input_ids, do_sample=True, max_length=50, pad_token_id=50256)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
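
To confirm that the quantized linear layers were swapped for the IPEX kernels, you can inspect one of the projection layers. This is a minimal sketch that continues the example above and assumes a Llama-style module layout (`model.model.layers[0].self_attn.q_proj`):

```python
# Sketch, continuing the example above: the AWQ linears should now be instances
# of the IPEX kernel class. The module path assumes a Llama-style architecture.
from awq.modules.linear.gemm_ipex import WQLinear_IPEX

q_proj = model.model.layers[0].self_attn.q_proj
print(type(q_proj))  # expected: <class 'awq.modules.linear.gemm_ipex.WQLinear_IPEX'>
assert isinstance(q_proj, WQLinear_IPEX)
```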

<Tip warning={true}>

Note this feature is supported on Intel CPUs.

</Tip>
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -21,6 +21,7 @@
"awq": [
"fuse_awq_modules",
"post_init_awq_exllama_modules",
"post_init_awq_ipex_modules",
"replace_quantization_scales",
"replace_with_awq_linear",
],
@@ -115,6 +116,7 @@
from .awq import (
fuse_awq_modules,
post_init_awq_exllama_modules,
post_init_awq_ipex_modules,
replace_quantization_scales,
replace_with_awq_linear,
)
27 changes: 24 additions & 3 deletions src/transformers/integrations/awq.py
@@ -145,6 +145,10 @@ def replace_with_awq_linear(
target_cls = WQLinear_ExllamaV2
else:
raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
elif quantization_config.version == AWQLinearVersion.IPEX:
from awq.modules.linear.gemm_ipex import WQLinear_IPEX

target_cls = WQLinear_IPEX
else:
raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
else:
@@ -266,8 +270,9 @@ def fuse_awq_modules(model, quantization_config):
# Replace layer norms
_fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)

# Replace MLP layers
_fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
# Replace MLP layers if awq version is not ipex.
if quantization_config.version != "ipex":
_fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)

# Replace attention layers
attention_has_been_fused = _fuse_awq_attention_layers(
@@ -372,7 +377,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
The `QuantAttentionFused` class as it only supports that class
for now.
"""
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_IPEX

module_has_been_fused = False

@@ -389,6 +394,9 @@
elif isinstance(q_proj, WQLinear_GEMM):
linear_target_cls = WQLinear_GEMM
cat_dim = 1
elif isinstance(q_proj, WQLinear_IPEX):
linear_target_cls = WQLinear_IPEX
cat_dim = 1
else:
raise ValueError("Unsupported q_proj type: {type(q_proj)}")

@@ -466,3 +474,16 @@ def post_init_awq_exllama_modules(model, exllama_config):
raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")

return model


def post_init_awq_ipex_modules(model):
"""
Runs post init for IPEX layers which performs:
- Weights packing, reordering and repacking
"""

from awq.modules.linear.gemm_ipex import ipex_post_init

model = ipex_post_init(model)

return model
8 changes: 5 additions & 3 deletions src/transformers/quantizers/quantizer_awq.py
@@ -46,9 +46,6 @@ def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run AWQ quantized model.")

if not is_auto_awq_available():
raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")

@@ -106,6 +103,11 @@ def _process_model_after_weight_loading(self, model):

model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)

if self.quantization_config.version == AWQLinearVersion.IPEX:
from ..integrations import post_init_awq_ipex_modules

model = post_init_awq_ipex_modules(model)

@property
def is_serializable(self):
# AWQ through auto-awq has been always serializable, except if the model is fused.
15 changes: 10 additions & 5 deletions src/transformers/utils/quantization_config.py
@@ -50,6 +50,7 @@ class AWQLinearVersion(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
EXLLAMA = "exllama"
IPEX = "ipex"

@staticmethod
def from_str(version: str):
@@ -60,6 +61,8 @@ def from_str(version: str):
return AWQLinearVersion.GEMV
elif version == "exllama":
return AWQLinearVersion.EXLLAMA
elif version == "ipex":
return AWQLinearVersion.IPEX
else:
raise ValueError(f"Unknown AWQLinearVersion {version}")

@@ -818,18 +821,20 @@ def post_init(self):
r"""
Safety checker that arguments are correct
"""
if not torch.cuda.is_available():
raise ValueError("AWQ is only available on GPU")

if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
raise ValueError(
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
)

self.version = AWQLinearVersion.from_str(self.version)
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
if self.version not in [
AWQLinearVersion.GEMM,
AWQLinearVersion.GEMV,
AWQLinearVersion.EXLLAMA,
AWQLinearVersion.IPEX,
]:
raise ValueError(
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA, AWQLinearVersion.IPEX] - not recognized version {self.version}"
)

if self.backend == AwqBackendPackingMethod.LLMAWQ:
26 changes: 26 additions & 0 deletions tests/quantization/autoawq/test_awq.py
@@ -21,6 +21,7 @@
from transformers.testing_utils import (
require_accelerate,
require_auto_awq,
require_intel_extension_for_pytorch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@@ -490,3 +491,28 @@ def test_load_quantized_model(self):
"TechxGenus/starcoder2-3b-AWQ", torch_dtype=torch.float16, device_map="cuda"
)
self.assertTrue(isinstance(quantized_model.model.layers[0].mlp.act, ScaledActivation))


@slow
@require_auto_awq
@require_accelerate
@require_intel_extension_for_pytorch
class AwqIPEXTest(unittest.TestCase):
def test_quantized_model_ipex(self):
"""
Simple test that checks that the quantized model works properly with the IPEX backend
"""
quantization_config = AwqConfig(version="ipex")

model = AutoModelForCausalLM.from_pretrained(
"TheBloke/WizardLM-1.0-Uncensored-Llama2-13B-AWQ",
quantization_config=quantization_config,
device_map="cpu",
)
tokenizer = AutoTokenizer.from_pretrained("TheBloke/WizardLM-1.0-Uncensored-Llama2-13B-AWQ")
input_ids = tokenizer.encode("How to make a cake", return_tensors="pt").to(model.device)
output = model.generate(input_ids, do_sample=False, max_length=20, pad_token_id=50256)
print(tokenizer.decode(output[0], skip_special_tokens=True))

expected_output = "How to make a cake with flour, sugar, eggs, and baking powder"
self.assertIn(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)