
Commit 7e1fa90

Enable LWQ for onnxrt WOQ in 3.x API (#1625)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 21cfeb8 commit 7e1fa90

12 files changed: +603 −80 lines changed


neural_compressor/onnxrt/algorithms/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,5 +17,6 @@
 from neural_compressor.onnxrt.algorithms.weight_only.rtn import apply_rtn_on_model
 from neural_compressor.onnxrt.algorithms.weight_only.gptq import apply_gptq_on_model
 from neural_compressor.onnxrt.algorithms.weight_only.awq import apply_awq_on_model
+from neural_compressor.onnxrt.algorithms.layer_wise import layer_wise_quant
 
-__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model"]
+__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model", "layer_wise_quant"]
neural_compressor/onnxrt/algorithms/layer_wise/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.onnxrt.algorithms.layer_wise.core import layer_wise_quant
+
+__all__ = ["layer_wise_quant"]
neural_compressor/onnxrt/algorithms/layer_wise/core.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 MIT HAN Lab
+# This source code is licensed under the MIT license
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from copy import deepcopy
+from pathlib import Path
+from typing import Callable, List, Union
+
+import onnx
+import onnxruntime as ort
+import transformers
+from packaging.version import Version
+
+from neural_compressor.common import Logger
+from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
+from neural_compressor.onnxrt.utils.onnx_model import ONNXModel
+from neural_compressor.onnxrt.utils.utility import check_model_with_infer_shapes
+
+logger = Logger().get_logger()
+
+__all__ = [
+    "layer_wise_quant",
+]
+
+
+def layer_wise_quant(
+    model: Union[onnx.ModelProto, ONNXModel, Path, str],
+    quant_func: Callable,
+    weight_config: dict,
+    data_reader: CalibrationDataReader = None,
+    *args,
+    **kwargs
+) -> ONNXModel:
+    """Quantize model layer by layer to save memory.
+
+    Args:
+        model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
+        quant_func (Callable): quantization algo function.
+        weight_config (dict): quantization config.
+        data_reader (CalibrationDataReader, optional): data_reader for calibration. Defaults to None.
+
+    Returns:
+        ONNXModel: layer-wise quantized model.
+    """
+    # TODO: remove the limitation for lwq
+    if Version(transformers.__version__) > Version("4.37.2"):
+        logger.warning(
+            "Models (such as llama-2) exported with transformers {} may fail in layer-wise quantization. "
+            "We recommend downgrading transformers to 4.37.2 and trying again.".format(transformers.__version__)
+        )
+
+    # check whether the model shape is inferred
+    if not check_model_with_infer_shapes(model):
+        logger.error(
+            "Before applying layer-wise quantization, please make sure to "
+            "run symbolic shape inference on your model as follows:\n"
+            "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n"
+            "model = onnx.load(your_model_path)\n"
+            "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n"
+            "onnx.save(out, infer_shape_model_path)\n"
+        )
+        raise ValueError("Failed to run layer-wise quantization.")
+
+    if not isinstance(model, ONNXModel):
+        model = ONNXModel(model, ignore_warning=True, load_external_data=False)
+
+    origin_model = deepcopy(model)
+
+    providers = kwargs.get("providers", ["CPUExecutionProvider"])
+
+    # get and check split nodes
+    split_nodes = origin_model.find_split_nodes()
+    if len(split_nodes) == 0:
+        logger.error(
+            "Can't find split nodes for layer-wise quantization. "
+            "We recommend applying graph optimization to your model as follows: \n"
+            "import onnxruntime as ort \n"
+            "sess_options = ort.SessionOptions() \n"
+            "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED "
+            "# or ORT_ENABLE_BASIC \n"
+            "sess_options.optimized_model_filepath = 'optimized_model_path' \n"
+            "ort.InferenceSession(infer_shape_model_path, sess_options)"
+        )
+        raise ValueError("Failed to run layer-wise quantization.")
+    logger.info(
+        "Will split model into {} parts to do layer-wise quantization".format(
+            len([node.name for node in split_nodes]) + 1
+        )
+    )
+    logger.debug(
+        "Will split model with these nodes for layer-wise quantization: {}".format([node.name for node in split_nodes])
+    )
+
+    split_idx = 1
+    model_to_split = [origin_model]
+    quantized_model_merged = None
+
+    require_data_reader = data_reader is not None
+    if require_data_reader:
+        lwq_data_reader = [data_reader]
+
+    while len(model_to_split) != 0:
+        # prepare model, node and data_reader for current split
+        split_model = model_to_split.pop(0)
+        split_node = split_nodes.pop(0)
+        if require_data_reader:
+            current_data_reader = lwq_data_reader.pop(0)
+
+        # if there are no remaining split nodes, this is the last split and both split models will be saved
+        save_both_split_models = True if len(split_nodes) == 0 else False
+
+        # split model with the given split node
+        split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
+            split_node.name, model.model_path, save_both_split_models
+        )
+        if not save_both_split_models:
+            # append split_model_part_2 to do the next split
+            model_to_split.append(split_model_part_2)
+
+        logger.info("Quantize split model {}".format(split_idx))
+        if require_data_reader:
+            # process data_reader for current split and next split
+            current_data_reader = _filter_data_reader_for_current_split_model(
+                split_model_part_1.model, current_data_reader
+            )
+            next_data_reader = _prepare_data_reader_for_next_split_model(
+                split_model_part_1.model_path, current_data_reader, providers
+            )
+            lwq_data_reader.append(next_data_reader)
+
+            # perform quantization
+            split_model_part_1_quantized = quant_func(
+                split_model_part_1,
+                weight_config=weight_config,
+                data_reader=current_data_reader,
+                return_modelproto=False,
+                **kwargs
+            )
+        else:
+            # perform quantization
+            split_model_part_1_quantized = quant_func(
+                split_model_part_1, weight_config=weight_config, return_modelproto=False, **kwargs
+            )
+
+        # check that the split model is valid
+        try:
+            ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
+        except Exception as e:
+            logger.error(
+                "Layer-wise quantized model {} can't be inferred correctly. "
+                "Please check the raised exception".format(split_idx)
+            )
+            raise e
+
+        # merge split quantized model
+        if quantized_model_merged is None:
+            quantized_model_merged = split_model_part_1_quantized
+            quantized_model_merged.write_external_data_to_new_location(overwrite=True)
+        else:
+            quantized_model_merged.merge_split_models(split_model_part_1_quantized)
+
+        split_idx += 1
+        # if this is the last split, quantize the last split model
+        if save_both_split_models:
+            logger.info("Quantize split model {}".format(split_idx))
+
+            # quantize split model
+            if require_data_reader:
+                # process data_reader for current split
+                current_data_reader = lwq_data_reader.pop(0)
+                current_data_reader = _filter_data_reader_for_current_split_model(
+                    split_model_part_2.model, current_data_reader
+                )
+
+                # perform quantization
+                split_model_part_2_quantized = quant_func(
+                    split_model_part_2,
+                    weight_config=weight_config,
+                    data_reader=current_data_reader,
+                    return_modelproto=False,
+                    **kwargs
+                )
+            else:
+                # perform quantization
+                split_model_part_2_quantized = quant_func(
+                    split_model_part_2, weight_config=weight_config, return_modelproto=False, **kwargs
+                )
+
+            # check that the split model is valid
+            try:
+                ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
+            except Exception as e:
+                logger.error(
+                    "Layer-wise quantized model {} can't be inferred correctly. "
+                    "Please check the raised exception".format(split_idx)
+                )
+                raise e
+
+            # merge split quantized model
+            if quantized_model_merged is None:
+                quantized_model_merged = split_model_part_2_quantized
+                quantized_model_merged.write_external_data_to_new_location(overwrite=True)
+            else:
+                quantized_model_merged.merge_split_models(split_model_part_2_quantized)
+
+    # reload external data to prevent external data file path errors
+    from onnx.external_data_helper import load_external_data_for_model
+
+    load_external_data_for_model(quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path))
+
+    return quantized_model_merged
+
+
+class DataReader(CalibrationDataReader):
+    """Data reader for layer-wise quantization."""
+
+    def __init__(self, data_list):
+        self.data_list = data_list
+        self.iter_next = iter(self.data_list)
+
+    def get_next(self):
+        return next(self.iter_next, None)
+
+    def rewind(self):
+        self.iter_next = iter(self.data_list)
+
+
+def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_reader: CalibrationDataReader):
+    """Filter data reader to remove data that is not in the model input.
+
+    Args:
+        model (onnx.ModelProto): onnx model.
+        data_reader (CalibrationDataReader): data reader.
+
+    Returns:
+        CalibrationDataReader: filtered data reader.
+    """
+    filter_inputs = []
+    input_names = [input.name for input in model.graph.input]
+    while True:
+        inputs = data_reader.get_next()
+        if not inputs:
+            break
+        filter_input = {
+            input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names
+        }
+        filter_inputs.append(filter_input)
+    return DataReader(filter_inputs)
+
+
+def _prepare_data_reader_for_next_split_model(
+    model_path: str,
+    data_reader: CalibrationDataReader,
+    providers: List[str] = ["CPUExecutionProvider"],
+):
+    """Prepare data reader for the next split model.
+
+    Get the output of the current split model and save it for the next split model.
+
+    Args:
+        model_path (str): path to onnx model.
+        data_reader (CalibrationDataReader): data reader.
+        providers (List[str], optional): providers to use. Defaults to ["CPUExecutionProvider"].
+
+    Returns:
+        CalibrationDataReader: data reader for the next split model.
+    """
+    data_reader = deepcopy(data_reader)
+
+    data_reader_for_next_split_model = []
+    session = ort.InferenceSession(model_path, providers=providers)
+    output_names = [output.name for output in session.get_outputs()]
+    while True:
+        inputs = data_reader.get_next()
+        if not inputs:
+            break
+        out = session.run(None, inputs)
+        inputs.update({name: value for name, value in zip(output_names, out)})
+        data_reader_for_next_split_model.append(inputs)
+    return DataReader(data_reader_for_next_split_model)
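Taken together, layer_wise_quant splits the shape-inferred model at the nodes returned by find_split_nodes, runs the supplied quant_func on each part (filtering the calibration reader to that part's inputs and feeding the part's outputs forward as the next part's calibration data), and merges the quantized parts back into one ONNXModel. A minimal driver sketch follows; the file paths and the quantization callback are illustrative placeholders, and the only contract assumed for the callback is the weight_config / data_reader / return_modelproto keywords that layer_wise_quant forwards above:

# Sketch only: paths, config contents and the callback body are placeholders.
import onnx
import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer

from neural_compressor.onnxrt.algorithms import layer_wise_quant

# layer_wise_quant refuses to run on a model without inferred shapes (see the error message above).
model = onnx.load("model.onnx")
inferred = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "model_infer_shape.onnx")

def my_quant_func(model, weight_config=None, data_reader=None, return_modelproto=True, **kwargs):
    # Placeholder for a weight-only algorithm (e.g. RTN/GPTQ/AWQ); it receives one split
    # at a time and must honour the keywords layer_wise_quant passes through.
    return model

quantized = layer_wise_quant(
    "model_infer_shape.onnx",
    quant_func=my_quant_func,
    weight_config={},  # per-op weight-only config, schema omitted here
    data_reader=None,  # pass a CalibrationDataReader for data-driven algorithms
)

Because each quantized split is validated with an ort.InferenceSession before merging, failures surface per part rather than only after the full model is rebuilt.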

neural_compressor/onnxrt/algorithms/weight_only/awq.py

Lines changed: 6 additions & 6 deletions
@@ -275,7 +275,7 @@ def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits,
 
 def awq_quantize(
     model: Union[onnx.ModelProto, ONNXModel, Path, str],
-    dataloader: CalibrationDataReader,
+    data_reader: CalibrationDataReader,
     weight_config: dict = {},
     num_bits: int = 4,
     group_size: int = 32,
@@ -289,7 +289,7 @@ def awq_quantize(
 
     Args:
         model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
-        dataloader (CalibrationDataReader): dataloader for calibration.
+        data_reader (CalibrationDataReader): data_reader for calibration.
         weight_config (dict, optional): quantization config
             For example,
             weight_config = {
@@ -323,8 +323,8 @@ def awq_quantize(
     full_ratio = {}
 
     if enable_mse_search:
-        inputs, so = prepare_inputs(model, dataloader, providers)
-        del dataloader
+        inputs, so = prepare_inputs(model, data_reader, providers)
+        del data_reader
 
         org_output = copy.deepcopy(model.model.graph.output)
         model.remove_tensors_from_outputs([i.name for i in org_output])
@@ -420,7 +420,7 @@ def apply_awq_on_model(
     Args:
         model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
        quant_config (dict): quantization config.
-        calibration_data_reader (CalibrationDataReader): dataloader for calibration.
+        calibration_data_reader (CalibrationDataReader): data_reader for calibration.
 
     Returns:
         onnx.ModelProto: quantized onnx model.
@@ -434,4 +434,4 @@ def apply_awq_on_model(
         if isinstance(op_config, AWQConfig):
            quant_config[op_name_type] = op_config.to_dict()
 
-    return awq_quantize(model, dataloader=calibration_data_reader, weight_config=quant_config, **kwargs)
+    return awq_quantize(model, data_reader=calibration_data_reader, weight_config=quant_config, **kwargs)
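The substantive change in this file is the rename of the calibration argument from dataloader to data_reader, matching the keyword layer_wise_quant forwards to its quant_func and the naming used elsewhere in the 3.x onnxrt API. A small sketch of a caller updated for the new keyword; the reader class, sample dict and model path are illustrative placeholders, with the reader implementing the same get_next/rewind interface as the DataReader class in core.py above:

# Sketch only: ToyDataReader, the sample dict and the model path are placeholders.
import numpy as np

from neural_compressor.onnxrt.algorithms.weight_only.awq import awq_quantize
from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader

class ToyDataReader(CalibrationDataReader):
    def __init__(self, samples):
        self.samples = samples
        self._iter = iter(self.samples)

    def get_next(self):
        return next(self._iter, None)

    def rewind(self):
        self._iter = iter(self.samples)

# Keys must match the model's graph input names; the values here are dummy tensors.
reader = ToyDataReader([{"input_ids": np.zeros((1, 8), dtype=np.int64)}])

quantized = awq_quantize(
    "model.onnx",
    data_reader=reader,  # was `dataloader=` before this commit
    weight_config={},
    num_bits=4,
    group_size=32,
)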
