
Commit 50884f3

enable lwq for onnxrt woq
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 3882e9c commit 50884f3

File tree

11 files changed: +600, -77 lines


neural_compressor/onnxrt/algorithms/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -17,5 +17,12 @@
 from neural_compressor.onnxrt.algorithms.weight_only.rtn import apply_rtn_on_model
 from neural_compressor.onnxrt.algorithms.weight_only.gptq import apply_gptq_on_model
 from neural_compressor.onnxrt.algorithms.weight_only.awq import apply_awq_on_model
+from neural_compressor.onnxrt.algorithms.layer_wise import layer_wise_quant
 
-__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model"]
+__all__ = [
+    "Smoother",
+    "apply_rtn_on_model",
+    "apply_gptq_on_model",
+    "apply_awq_on_model",
+    "layer_wise_quant"
+]
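
With this change, layer_wise_quant is exported next to the existing weight-only entry points. A minimal import sketch of the new surface (illustrative only, not part of the diff):

    from neural_compressor.onnxrt.algorithms import layer_wise_quant, apply_rtn_on_model

    # layer_wise_quant is the new layer-wise entry point; apply_rtn_on_model and the
    # other apply_*_on_model helpers remain available as before.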

neural_compressor/onnxrt/algorithms/layer_wise/__init__.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.onnxrt.algorithms.layer_wise.core import layer_wise_quant
+
+__all__ = ["layer_wise_quant"]

neural_compressor/onnxrt/algorithms/layer_wise/core.py

Lines changed: 285 additions & 0 deletions

@@ -0,0 +1,285 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 MIT HAN Lab
+# This source code is licensed under the MIT license
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from copy import deepcopy
+from pathlib import Path
+from typing import Union, Callable, List
+
+import onnx
+import onnxruntime as ort
+
+from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
+from neural_compressor.onnxrt.utils.onnx_model import ONNXModel
+from neural_compressor.onnxrt.utils.utility import check_model_with_infer_shapes
+from neural_compressor.common import Logger
+
+logger = Logger().get_logger()
+
+__all__ = [
+    "layer_wise_quant",
+]
+
+def layer_wise_quant(
+    model: Union[onnx.ModelProto, ONNXModel, Path, str],
+    quant_func: Callable,
+    weight_config: dict,
+    data_reader: CalibrationDataReader = None,
+    *args,
+    **kwargs
+) -> ONNXModel:
+    """Quantize model layer by layer to save memory.
+
+    Args:
+        model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
+        quant_func (Callable): quantization algo function.
+        weight_config (dict): quantization config.
+        data_reader (CalibrationDataReader, optional): data_reader for calibration. Defaults to None.
+
+    Returns:
+        ONNXModel: layer-wise quantized model.
+    """
+    # check whether the model shape is inferred
+    if not check_model_with_infer_shapes(model):
+        logger.error(
+            "Before applying layer-wise quantization, please make sure to "
+            "run symbolic shape inference on your model like follows:\n"
+            "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n"
+            "model = onnx.load(your_model_path)\n"
+            "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n"
+            "onnx.save(out, infer_shape_model_path)\n"
+        )
+        raise ValueError("Fail to run layer-wise quantization.")
+
+    if not isinstance(model, ONNXModel):
+        model = ONNXModel(model, ignore_warning=True, load_external_data=False)
+
+    origin_model = deepcopy(model)
+
+    providers = kwargs.get("providers", ["CPUExecutionProvider"])
+
+    # get and check split nodes
+    split_nodes = origin_model.find_split_nodes()
+    if len(split_nodes) == 0:
+        logger.error(
+            "Can't find split nodes for layer-wise quantization. "
+            "We recommend applying graph optimization for your model like follows: \n"
+            "import onnxruntime as ort \n"
+            "sess_options = ort.SessionOptions() \n"
+            "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED "
+            "# or ORT_ENABLE_BASIC \n"
+            "sess_options.optimized_model_filepath = 'optimized_model_path' \n"
+            "ort.InferenceSession(infer_shape_model_path, sess_options)"
+        )
+        raise ValueError("Fail to run layer-wise quantization.")
+    logger.info(
+        "Will split model into {} parts to do layer-wise quantization".format(
+            len([node.name for node in split_nodes]) + 1
+        )
+    )
+    logger.debug(
+        "Will split model with these nodes for layer-wise quantization: {}".format(
+            [node.name for node in split_nodes]
+        )
+    )
+
+    split_idx = 1
+    model_to_split = [origin_model]
+    quantized_model_merged = None
+
+    require_data_reader = data_reader is not None
+    if require_data_reader:
+        lwq_data_reader = [data_reader]
+
+    while len(model_to_split) != 0:
+        # prepare model, node and data_reader for the current split
+        split_model = model_to_split.pop(0)
+        split_node = split_nodes.pop(0)
+        if require_data_reader:
+            current_data_reader = lwq_data_reader.pop(0)
+
+        # if there are no remaining split nodes, this is the last split and both split models will be saved
+        save_both_split_models = True if len(split_nodes) == 0 else False
+
+        # split the model at the given split node
+        split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
+            split_node.name, model.model_path, save_both_split_models
+        )
+        if not save_both_split_models:
+            # append split_model_part_2 to do the next split
+            model_to_split.append(split_model_part_2)
+
+        logger.info("Quantize split model {}".format(split_idx))
+        if require_data_reader:
+            # process data_reader for the current split and the next split
+            current_data_reader = _filter_data_reader_for_current_split_model(split_model_part_1.model, current_data_reader)
+            next_data_reader = _prepare_data_reader_for_next_split_model(split_model_part_1.model_path, current_data_reader, providers)
+            lwq_data_reader.append(next_data_reader)
+
+            # perform quantization
+            split_model_part_1_quantized = quant_func(
+                split_model_part_1,
+                weight_config=weight_config,
+                data_reader=current_data_reader,
+                return_modelproto=False,
+                **kwargs
+            )
+        else:
+            # perform quantization
+            split_model_part_1_quantized = quant_func(
+                split_model_part_1,
+                weight_config=weight_config,
+                return_modelproto=False,
+                **kwargs
+            )
+
+        # check that the split model is valid
+        try:
+            ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
+        except Exception as e:
+            logger.error("Layer-wise quantized model {} can't be inferred correctly. "
+                         "Please check the raised exception".format(split_idx))
+            raise e
+
+        # merge the quantized split model
+        if quantized_model_merged is None:
+            quantized_model_merged = split_model_part_1_quantized
+            quantized_model_merged.write_external_data_to_new_location(overwrite=True)
+        else:
+            quantized_model_merged.merge_split_models(split_model_part_1_quantized)
+
+        split_idx += 1
+        # if this is the last split, quantize the second split model as well
+        if save_both_split_models:
+            logger.info("Quantize split model {}".format(split_idx))
+
+            # quantize split model
+            if require_data_reader:
+                # process data_reader for the current split
+                current_data_reader = lwq_data_reader.pop(0)
+                current_data_reader = _filter_data_reader_for_current_split_model(split_model_part_2.model, current_data_reader)
+
+                # perform quantization
+                split_model_part_2_quantized = quant_func(
+                    split_model_part_2,
+                    weight_config=weight_config,
+                    data_reader=current_data_reader,
+                    return_modelproto=False,
+                    **kwargs
+                )
+            else:
+                # perform quantization
+                split_model_part_2_quantized = quant_func(
+                    split_model_part_2,
+                    weight_config=weight_config,
+                    return_modelproto=False,
+                    **kwargs
+                )
+
+            # check that the split model is valid
+            try:
+                ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
+            except Exception as e:
+                logger.error("Layer-wise quantized model {} can't be inferred correctly. "
+                             "Please check the raised exception".format(split_idx))
+                raise e
+
+            # merge the quantized split model
+            if quantized_model_merged is None:
+                quantized_model_merged = split_model_part_2_quantized
+                quantized_model_merged.write_external_data_to_new_location(overwrite=True)
+            else:
+                quantized_model_merged.merge_split_models(split_model_part_2_quantized)
+
+    # reload external data to prevent external data file path errors
+    from onnx.external_data_helper import load_external_data_for_model
+    load_external_data_for_model(quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path))
+
+    return quantized_model_merged
+
+
+class DataReader(CalibrationDataReader):
+    """Data reader for layer-wise quantization."""
+
+    def __init__(self, data_list):
+        self.data_list = data_list
+        self.iter_next = iter(self.data_list)
+
+    def get_next(self):
+        return next(self.iter_next, None)
+
+    def rewind(self):
+        self.iter_next = iter(self.data_list)
+
+
+def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_reader: CalibrationDataReader):
+    """Filter data reader to remove data that is not in the model input.
+
+    Args:
+        model (onnx.ModelProto): onnx model.
+        data_reader (CalibrationDataReader): data reader.
+
+    Returns:
+        CalibrationDataReader: filtered data reader.
+    """
+    filter_inputs = []
+    input_names = [input.name for input in model.graph.input]
+    while True:
+        inputs = data_reader.get_next()
+        if not inputs:
+            break
+        filter_input = {
+            input_name: input_tensor
+            for input_name, input_tensor in inputs.items()
+            if input_name in input_names
+        }
+        filter_inputs.append(filter_input)
+    return DataReader(filter_inputs)
+
+def _prepare_data_reader_for_next_split_model(
+    model_path: str,
+    data_reader: CalibrationDataReader,
+    providers: List[str] = ["CPUExecutionProvider"],
+):
+    """Prepare data reader for the next split model.
+
+    Get the output of the current split model and save it for the next split model.
+
+    Args:
+        model_path (str): path to onnx model.
+        data_reader (CalibrationDataReader): data reader.
+        providers (List[str], optional): providers to use. Defaults to ["CPUExecutionProvider"].
+
+    Returns:
+        CalibrationDataReader: data reader for the next split model.
+    """
+    data_reader = deepcopy(data_reader)
+
+    data_reader_for_next_split_model = []
+    session = ort.InferenceSession(model_path, providers=providers)
+    output_names = [output.name for output in session.get_outputs()]
+    while True:
+        inputs = data_reader.get_next()
+        if not inputs:
+            break
+        out = session.run(None, inputs)
+        inputs.update({name: value for name, value in zip(output_names, out)})
+        data_reader_for_next_split_model.append(inputs)
+    return DataReader(data_reader_for_next_split_model)
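
For context, a minimal sketch of how this new entry point could be driven end to end. This is not part of the commit: the file names are placeholders, and the use of rtn_quantize as the per-split quant_func (with an empty weight_config) is an assumption for illustration; any weight-only quantization function exposing the same weight_config / data_reader / return_modelproto keywords would fit.

    # Illustrative sketch only; paths and the choice of rtn_quantize are assumptions.
    import onnx
    import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer

    from neural_compressor.onnxrt.algorithms import layer_wise_quant
    from neural_compressor.onnxrt.algorithms.weight_only.rtn import rtn_quantize  # assumed quant_func

    # 1. Layer-wise quantization requires a shape-inferred model (see the error hint above).
    model = onnx.load("model.onnx")
    inferred = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    onnx.save(inferred, "model_infer_shape.onnx")

    # 2. Quantize split by split; weight_config follows the weight-only config layout.
    qmodel = layer_wise_quant(
        "model_infer_shape.onnx",
        quant_func=rtn_quantize,
        weight_config={},  # placeholder: real configs map op (name, type) to quantization params
    )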

neural_compressor/onnxrt/algorithms/weight_only/awq.py

Lines changed: 4 additions & 4 deletions

@@ -275,7 +275,7 @@ def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits,
 
 def awq_quantize(
     model: Union[onnx.ModelProto, ONNXModel, Path, str],
-    dataloader: CalibrationDataReader,
+    data_reader: CalibrationDataReader,
     weight_config: dict = {},
     num_bits: int = 4,
     group_size: int = 32,
@@ -289,7 +289,7 @@ def awq_quantize(
 
     Args:
         model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
-        dataloader (CalibrationDataReader): dataloader for calibration.
+        data_reader (CalibrationDataReader): data_reader for calibration.
         weight_config (dict, optional): quantization config
             For example,
                 weight_config = {
@@ -420,7 +420,7 @@ def apply_awq_on_model(
     Args:
        model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
        quant_config (dict): quantization config.
-       calibration_data_reader (CalibrationDataReader): dataloader for calibration.
+       calibration_data_reader (CalibrationDataReader): data_reader for calibration.
 
    Returns:
        onnx.ModelProto: quantized onnx model.
@@ -434,4 +434,4 @@ def apply_awq_on_model(
         if isinstance(op_config, AWQConfig):
             quant_config[op_name_type] = op_config.to_dict()
 
-    return awq_quantize(model, dataloader=calibration_data_reader, weight_config=quant_config, **kwargs)
+    return awq_quantize(model, data_reader=calibration_data_reader, weight_config=quant_config, **kwargs)
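
Call sites now pass the calibration reader through the data_reader keyword. A hedged sketch of direct use (the model path, the input name/shape, and the small reader class are placeholders; the reader simply mirrors the DataReader helper added in core.py above):

    # Illustrative sketch; "model.onnx" and the "input_ids" tensor are placeholders.
    import numpy as np

    from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
    from neural_compressor.onnxrt.algorithms.weight_only.awq import awq_quantize


    class ListDataReader(CalibrationDataReader):
        """Tiny calibration reader over an in-memory list, mirroring core.py's DataReader."""

        def __init__(self, samples):
            self.samples = samples
            self._iter = iter(self.samples)

        def get_next(self):
            return next(self._iter, None)

        def rewind(self):
            self._iter = iter(self.samples)


    reader = ListDataReader([{"input_ids": np.ones((1, 32), dtype=np.int64)}])
    qmodel = awq_quantize("model.onnx", data_reader=reader, weight_config={})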
