
Commit a0066d4

Fix transformers rtn layer-wise quant (#2008)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 802a5af commit a0066d4

File tree: 7 files changed (+78, -18 lines)


examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/README.md

Lines changed: 6 additions & 1 deletion
@@ -116,13 +116,18 @@ Pytorch and Intel-extension-for-pytorch version for intel GPU > 2.1 are required
 ```bash
 pip install -r requirements_GPU.txt
 pip install transformers==4.38.1 # llama use 4.38.1
-source /opt/intel/oneapi/setvars.sh
 git clone https://github.com/intel/intel-extension-for-pytorch.git ipex-gpu
 cd ipex-gpu
 git submodule update --init --recursive
 export USE_AOT_DEVLIST='pvc,ats-m150'
 export BUILD_WITH_CPU=OFF
 
+export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/:$LD_LIBRARY_PATH
+export OCL_ICD_VENDORS=/etc/OpenCL/vendors
+export CCL_ROOT=${CONDA_PREFIX}
+source /opt/intel/oneapi/setvars.sh --force
+export LLM_ACC_TEST=1
+
 python setup.py install
 ```
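Not part of the patch, but a quick way to confirm the environment above is wired up correctly. A minimal check, assuming the standard intel-extension-for-pytorch XPU APIs:

```python
# Minimal sanity check (not from the patch): after building ipex-gpu and sourcing
# oneAPI, the XPU backend should be importable and see at least one device.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  # registers torch.xpu

print("XPU available:", torch.xpu.is_available())
print("XPU device count:", torch.xpu.device_count())
```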

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@
 tokenizer.save_pretrained(args.output_dir)
 
 enable_optimize_transformers = False
-opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]
+opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen", "phi3"]
 
 if config.model_type in opt_gpu_model_type_list:
     enable_optimize_transformers = True

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 3 additions & 7 deletions
@@ -130,17 +130,16 @@ def convert(
 
     if use_layer_wise:
         from neural_compressor.common.utils import DEFAULT_WORKSPACE
-        from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module, register_weight_hooks
+        from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module
 
         if model_path == "":
             model_path = model.path
         assert model_path, "model_path should not be None."
         model_path = get_path(model_path)
 
-        register_weight_hooks(model, model_path, device=device, clean_weight=True)
-
     for name, m in model.named_modules():
-
+        if use_layer_wise and len(list(m.named_children())) == 0:
+            load_module(model, name, model_path, device=device)
         if not isinstance(m, supported_layers):
             continue
         if name in weight_config:  # pragma: no cover
@@ -192,9 +191,6 @@ def convert(
     logger.debug(f"RTN quantized module:{name, m}")
     logger.debug(log_msg)
 
-    if use_layer_wise:
-        load_module(model, name, model_path, device=device)
-
     # for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
     if is_transformers_imported():
         transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
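For context, a minimal sketch of the control flow this hunk establishes (not the verbatim library code; `model`, `device`, and the actual RTN math are assumed from the surrounding `convert()`): instead of registering weight hooks up front, each leaf module's weights are pulled from the checkpoint with `load_module` right before that module is quantized.

```python
# Sketch only: the new layer-wise flow inside convert(); `model` is assumed to be
# an empty-weight model and `device` comes from the caller.
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module

model_path = get_path(model.path)  # resolve the on-disk checkpoint location
for name, m in model.named_modules():
    if len(list(m.named_children())) == 0:  # leaf modules only
        # materialize just this module's weights from disk before quantizing it,
        # so the whole model never needs to be resident in memory at once
        load_module(model, name, model_path, device=device)
    # ... RTN-quantize `m`, then continue with the next module
```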

neural_compressor/torch/utils/utility.py

Lines changed: 2 additions & 2 deletions
@@ -331,11 +331,11 @@ def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
     if cls.__base__ == _BaseAutoModelClass:
         config = AutoConfig.from_pretrained(path, **kwargs)
         with init_empty_weights():
-            model = cls.from_config(config)
+            model = cls.from_config(config, **kwargs)
     else:  # pragma: no cover
         config = cls.config_class.from_pretrained(path, **kwargs)
         with init_empty_weights():
-            model = cls(config)
+            model = cls(config, **kwargs)
     model.tie_weights()
     model.eval()
     model.path = pretrained_model_name_or_path
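The point of forwarding `**kwargs` here is that options such as `trust_remote_code` now reach model construction, so architectures that ship custom code can still be instantiated with empty (meta) weights. A small hedged usage sketch; the model id is a placeholder:

```python
# Hedged sketch: build a weight-less skeleton of a checkpoint for layer-wise work.
# "some-org/some-model" is a placeholder; trust_remote_code matters only for
# checkpoints that ship their own modeling code.
from neural_compressor.torch import load_empty_model

empty_model = load_empty_model("some-org/some-model", trust_remote_code=True)
print(empty_model.path)  # the original name/path is kept for loading weights later
```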

neural_compressor/transformers/models/modeling_auto.py

Lines changed: 27 additions & 1 deletion
@@ -134,7 +134,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
     (RtnConfig, AwqConfig, TeqConfig, GPTQConfig, AutoRoundConfig),
 ):
     logger.info("Applying Weight Only Quantization.")
-    if use_xpu:
+    # set use_layer_wise on client
+    if hasattr(quantization_config, "use_layer_wise"):
+        import neural_compressor.torch.utils as torch_utils
+
+        process_type = torch_utils.get_processor_type_from_user_config()
+        if process_type == torch_utils.ProcessorType.Client:
+            quantization_config.use_layer_wise = True
+
+    if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
+        from transformers.dynamic_module_utils import resolve_trust_remote_code
+
+        from neural_compressor.torch import load_empty_model
+
+        trust_remote_code = kwargs.get("trust_remote_code", None)
+        has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
+        has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
+        trust_remote_code = resolve_trust_remote_code(
+            trust_remote_code,
+            pretrained_model_name_or_path,
+            has_local_code,
+            has_remote_code,
+        )
+
+        model = load_empty_model(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
+        if use_cpu:
+            quantization_config.post_init_cpu()
+    elif use_xpu:
         # TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device.
         kwargs["low_cpu_mem_usage"] = True
         kwargs["device_map"] = "cpu"

neural_compressor/transformers/quantization/utils.py

Lines changed: 6 additions & 6 deletions
@@ -153,7 +153,6 @@ def _replace_linear(
     "fp16": ipex.quantization.WoqLowpMode.FP16,
     "int8": ipex.quantization.WoqLowpMode.INT8,
 }
-
 ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
     weight_dtype=weight_dtype[quantization_config.bits],
     lowp_mode=compute_dtype[quantization_config.compute_dtype],
@@ -366,11 +365,6 @@ def convert_to_quantized_model(model, config, device="cpu"):
 
     # mapping to INC config
     dtype = "int4" if config.weight_dtype == "int4_fullrange" else config.weight_dtype
-    import neural_compressor.torch.utils as torch_utils
-
-    process_type = torch_utils.get_processor_type_from_user_config()
-    if process_type == torch_utils.ProcessorType.Client:
-        config.use_layer_wise = True
     if config.quant_method.value == "rtn":
         quant_config = RTNConfig(
             dtype=dtype,
@@ -529,6 +523,12 @@ def convert_to_quantized_model(model, config, device="cpu"):
     if orig_dtype != torch.float32:
         q_model.to(dtype=orig_dtype)
 
+    if config.use_layer_wise and not (q_model.device == device or q_model.device.type == device):
+        logger.warning(
+            "Do not convert device to avoid out of memory. Recommend using saved quantized model to inference."
+        )
+        return q_model
+
     return q_model.to(device)
 
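The new early return means that with layer-wise enabled, `convert_to_quantized_model` may hand back a model that is not yet on the requested device rather than risk an out-of-memory transfer; the suggested pattern is to save the quantized model and reload it for inference. A hedged continuation of the sketch above (paths are placeholders):

```python
# Hedged sketch: persist the layer-wise quantized model and reload it for inference
# instead of moving the in-memory copy onto the target device.
from neural_compressor.transformers import AutoModelForCausalLM

output_dir = "./quantized_model"        # arbitrary output directory
woq_model.save_pretrained(output_dir)   # woq_model from the sketch above

loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
# loaded_model is ready for inference on the desired device
```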

test/3x/torch/quantization/weight_only/test_transfomers.py

Lines changed: 33 additions & 0 deletions
@@ -115,6 +115,39 @@ def test_save_load(self):
         loaded_output = loaded_model(dummy_input)[0]
         assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."
 
+    def test_use_layer_wise(self):
+        model_name_or_path = self.model_name_or_path
+
+        fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+        dummy_input = fp32_model.dummy_inputs["input_ids"]
+
+        # RTN
+        # use_layer_wise=True
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output = woq_model(dummy_input)[0]
+
+        # save
+        output_dir = "./transformers_tmp"
+        woq_model.save_pretrained(output_dir)
+
+        # load
+        loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
+        loaded_output = loaded_model(dummy_input)[0]
+        assert torch.equal(woq_output, loaded_output), "loaded output should be same. Please double check."
+
+        # use_layer_wise=False
+        woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=False)
+        woq_model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            quantization_config=woq_config,
+        )
+        woq_output2 = woq_model(dummy_input)[0]
+        assert torch.equal(woq_output, woq_output2), "use_layer_wise output should be same. Please double check."
+
     def test_loading_autoawq_model(self):
         user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
         tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)
