
Commit 35181db

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4fd03c2 commit 35181db
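pre-commit.ci runs the repository's configured hooks and commits whatever fixes they make. As a minimal local equivalent, sketched in Python (it simply shells out to the pre-commit CLI; this assumes pre-commit is installed and the repository ships a .pre-commit-config.yaml, which the commit message implies):

    # Sketch: reproduce pre-commit.ci's auto-fixes locally (assumes `pre-commit` is installed
    # and the repo defines .pre-commit-config.yaml).
    import subprocess

    # Run every configured hook against the whole working tree.
    # pre-commit exits non-zero when a hook fails or modifies files.
    result = subprocess.run(["pre-commit", "run", "--all-files"], check=False)
    print("hooks modified or flagged files" if result.returncode != 0 else "working tree already clean")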

7 files changed: 46 additions & 48 deletions
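The diffs below are mechanical style fixes of the kind hooks such as black and isort apply automatically: imports re-sorted and grouped, long call sites re-wrapped to the configured line length, two blank lines enforced between top-level definitions, and single quotes normalized to double quotes. A rough before/after sketch of the same pattern, using hypothetical names rather than code from this repository:

    # Before: unsorted imports, hand-wrapped call, single quotes (illustrative only).
    from typing import Union, List, Callable

    def describe(values, sep=', '):
        return 'values: {}'.format(
            sep.join(str(v) for v in values)
        )

    # After: roughly what black/isort-style hooks produce for the same code.
    from typing import Callable, List, Union


    def describe(values, sep=", "):
        return "values: {}".format(sep.join(str(v) for v in values))

    print(describe([1, 2, 3]))  # values: 1, 2, 3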

neural_compressor/onnxrt/algorithms/__init__.py

Lines changed: 1 addition & 7 deletions
@@ -19,10 +19,4 @@
 from neural_compressor.onnxrt.algorithms.weight_only.awq import apply_awq_on_model
 from neural_compressor.onnxrt.algorithms.layer_wise import layer_wise_quant
 
-__all__ = [
-    "Smoother",
-    "apply_rtn_on_model",
-    "apply_gptq_on_model",
-    "apply_awq_on_model",
-    "layer_wise_quant"
-]
+__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model", "layer_wise_quant"]

neural_compressor/onnxrt/algorithms/layer_wise/core.py

Lines changed: 26 additions & 23 deletions
@@ -21,22 +21,23 @@
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Union, Callable, List
+from typing import Callable, List, Union
 
 import onnx
 import onnxruntime as ort
 
+from neural_compressor.common import Logger
 from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
 from neural_compressor.onnxrt.utils.onnx_model import ONNXModel
 from neural_compressor.onnxrt.utils.utility import check_model_with_infer_shapes
-from neural_compressor.common import Logger
 
 logger = Logger().get_logger()
 
 __all__ = [
     "layer_wise_quant",
 ]
 
+
 def layer_wise_quant(
     model: Union[onnx.ModelProto, ONNXModel, Path, str],
     quant_func: Callable,
@@ -95,9 +96,7 @@ def layer_wise_quant(
         )
     )
     logger.debug(
-        "Will split model with these nodes for layer-wise quantization: {}".format(
-            [node.name for node in split_nodes]
-        )
+        "Will split model with these nodes for layer-wise quantization: {}".format([node.name for node in split_nodes])
     )
 
     split_idx = 1
@@ -129,8 +128,12 @@ def layer_wise_quant(
         logger.info("Quantize split model {}".format(split_idx))
         if require_data_reader:
             # process data_reader for current split and next split
-            current_data_reader = _filter_data_reader_for_current_split_model(split_model_part_1.model, current_data_reader)
-            next_data_reader = _prepare_data_reader_for_next_split_model(split_model_part_1.model_path, current_data_reader, providers)
+            current_data_reader = _filter_data_reader_for_current_split_model(
+                split_model_part_1.model, current_data_reader
+            )
+            next_data_reader = _prepare_data_reader_for_next_split_model(
+                split_model_part_1.model_path, current_data_reader, providers
+            )
             lwq_data_reader.append(next_data_reader)
 
             # perform quantization
@@ -144,18 +147,17 @@ def layer_wise_quant(
         else:
             # perform quantization
             split_model_part_1_quantized = quant_func(
-                split_model_part_1,
-                weight_config=weight_config,
-                return_modelproto=False,
-                **kwargs
+                split_model_part_1, weight_config=weight_config, return_modelproto=False, **kwargs
             )
 
         # check split model is valid
         try:
             ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
         except Exception as e:
-            logger.error("Layer-wise quantized model {} can't be inferred correctly. "
-                         "Please check the raise exception".format(split_idx))
+            logger.error(
+                "Layer-wise quantized model {} can't be inferred correctly. "
+                "Please check the raise exception".format(split_idx)
+            )
             raise e
 
         # merge split quantized model
@@ -174,7 +176,9 @@ def layer_wise_quant(
             if require_data_reader:
                 # process data_reader for current split
                 current_data_reader = lwq_data_reader.pop(0)
-                current_data_reader = _filter_data_reader_for_current_split_model(split_model_part_2.model, current_data_reader)
+                current_data_reader = _filter_data_reader_for_current_split_model(
+                    split_model_part_2.model, current_data_reader
+                )
 
                 # perform quantization
                 split_model_part_2_quantized = quant_func(
@@ -187,18 +191,17 @@ def layer_wise_quant(
             else:
                 # perform quantization
                 split_model_part_2_quantized = quant_func(
-                    split_model_part_2,
-                    weight_config=weight_config,
-                    return_modelproto=False,
-                    **kwargs
+                    split_model_part_2, weight_config=weight_config, return_modelproto=False, **kwargs
                 )
 
             # check split model is valid
            try:
                 ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
             except Exception as e:
-                logger.error("Layer-wise quantized model {} can't be inferred correctly. "
-                             "Please check the raise exception".format(split_idx))
+                logger.error(
+                    "Layer-wise quantized model {} can't be inferred correctly. "
+                    "Please check the raise exception".format(split_idx)
+                )
                 raise e
 
             # merge split quantized model
@@ -210,6 +213,7 @@ def layer_wise_quant(
 
     # reload external data to prevent external data file path errors
     from onnx.external_data_helper import load_external_data_for_model
+
     load_external_data_for_model(quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path))
 
     return quantized_model_merged
@@ -246,13 +250,12 @@ def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_rea
         if not inputs:
             break
         filter_input = {
-            input_name: input_tensor
-            for input_name, input_tensor in inputs.items()
-            if input_name in input_names
+            input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names
         }
         filter_inputs.append(filter_input)
     return DataReader(filter_inputs)
 
+
 def _prepare_data_reader_for_next_split_model(
     model_path: str,
     data_reader: CalibrationDataReader,

neural_compressor/onnxrt/algorithms/weight_only/gptq.py

Lines changed: 4 additions & 5 deletions
@@ -439,13 +439,12 @@ def apply_gptq_on_model(
             quant_func=gptq_quantize,
             weight_config=quant_config,
             data_reader=calibration_data_reader,
-            **quant_kwargs)
+            **quant_kwargs
+        )
     else:
         quantized_model = gptq_quantize(
-            model,
-            data_reader=calibration_data_reader,
-            weight_config=quant_config,
-            **quant_kwargs)
+            model, data_reader=calibration_data_reader, weight_config=quant_config, **quant_kwargs
+        )
 
     if isinstance(quantized_model, ONNXModel):
         quantized_model = quantized_model.model

neural_compressor/onnxrt/algorithms/weight_only/rtn.py

Lines changed: 2 additions & 4 deletions
@@ -213,11 +213,9 @@ def apply_rtn_on_model(model: Union[onnx.ModelProto, ONNXModel, Path, str], quan
     if layer_wise:
         from neural_compressor.onnxrt.algorithms import layer_wise_quant
 
-        quantized_model = layer_wise_quant(
-            model, quant_func=rtn_quantize, weight_config=quant_config, **quant_kwargs)
+        quantized_model = layer_wise_quant(model, quant_func=rtn_quantize, weight_config=quant_config, **quant_kwargs)
     else:
-        quantized_model = rtn_quantize(
-            model, weight_config=quant_config, **quant_kwargs)
+        quantized_model = rtn_quantize(model, weight_config=quant_config, **quant_kwargs)
 
     if isinstance(quantized_model, ONNXModel):
         quantized_model = quantized_model.model

neural_compressor/onnxrt/utils/onnx_model.py

Lines changed: 4 additions & 3 deletions
@@ -74,6 +74,7 @@ def model_path(self, path):
     def check_is_large_model(self):
         """Check model > 2GB."""
         from neural_compressor.onnxrt.utils.utility import MAXIMUM_PROTOBUF
+
         init_size = 0
         for init in self.model.graph.initializer:
             # if initializer has external data location, return True
@@ -417,7 +418,9 @@ def topological_sort(self, enable_subgraph=False):
     def get_nodes_chain(self, start, stop, result_chain=[]):
         """Get nodes chain with given start node and stop node."""
         from collections import deque
+
         from onnx import NodeProto
+
         from neural_compressor.onnxrt.utils.utility import find_by_name
 
         # process start node list
@@ -818,9 +821,7 @@ def find_split_nodes(self):
         split_nodes = self.find_split_node_for_layer_wise_quantization()
         return split_nodes
 
-    def split_model_with_node(
-        self, split_node_name, path_of_model_to_split, save_both_split_models=True
-    ):
+    def split_model_with_node(self, split_node_name, path_of_model_to_split, save_both_split_models=True):
         """Split model into two parts at a given node.
 
         Args:

neural_compressor/onnxrt/utils/utility.py

Lines changed: 2 additions & 1 deletion
@@ -17,8 +17,8 @@
 
 import numpy as np
 import onnx
-from packaging.version import Version
 import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer
+from packaging.version import Version
 
 from neural_compressor.common import Logger
 
@@ -274,6 +274,7 @@ def quantize_data(data, quantize_range, qType, scheme):
     quantized_data = _quantize_data_with_scale_zero(data, qType, scheme, scale, zero_point)
     return rmin, rmax, zero_point, scale, quantized_data
 
+
 def check_model_with_infer_shapes(model):
     """Check if the model has been shape inferred."""
     from neural_compressor.onnxrt.utils.onnx_model import ONNXModel

test/3x/onnxrt/quantization/layer_wise/test_layer_wise.py

Lines changed: 7 additions & 5 deletions
@@ -1,17 +1,17 @@
 import os
-import torch
 import shutil
 import unittest
 from copy import deepcopy
-from transformers import AutoTokenizer
 
 import onnx
-from optimum.exporters.onnx import main_export
 import onnxruntime as ort
 import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer
+import torch
+from optimum.exporters.onnx import main_export
+from transformers import AutoTokenizer
 
-from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
 from neural_compressor.common import Logger
+from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
 
 logger = Logger().get_logger()
 
@@ -24,6 +24,7 @@ def find_onnx_file(folder_path):
                 return os.path.join(root, file)
     return None
 
+
 class DummyNLPDataloader(CalibrationDataReader):
     def __init__(self, model_name):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -51,6 +52,7 @@ def get_next(self):
     def rewind(self):
         self.iter_next = iter(self.encoded_list)
 
+
 class TestLayerWiseQuant(unittest.TestCase):
     @classmethod
     def setUpClass(self):
@@ -60,7 +62,7 @@ def setUpClass(self):
 
         model = onnx.load(model_path)
         model = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)
-        infer_shape_model_path = 'llama-2-tiny/model-infer-shape.onnx'
+        infer_shape_model_path = "llama-2-tiny/model-infer-shape.onnx"
         onnx.save(model, infer_shape_model_path)
 
         sess_options = ort.SessionOptions()

0 commit comments

Comments
 (0)