#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 MIT HAN Lab
# This source code is licensed under the MIT license
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

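"""Layer-wise quantization for ONNX Runtime models.

The model is split at boundary ("split") nodes into parts, each part is
quantized in turn by a user-provided quantization function, and the quantized
parts are merged back together, so only one part has to be held in memory at
a time.
"""
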
import os
from copy import deepcopy
from pathlib import Path
from typing import Callable, List, Optional, Union

import onnx
import onnxruntime as ort
import transformers
from packaging.version import Version

from neural_compressor.common import Logger
from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader
from neural_compressor.onnxrt.utils.onnx_model import ONNXModel
from neural_compressor.onnxrt.utils.utility import check_model_with_infer_shapes

logger = Logger().get_logger()

__all__ = [
    "layer_wise_quant",
]


def layer_wise_quant(
    model: Union[onnx.ModelProto, ONNXModel, Path, str],
    quant_func: Callable,
    weight_config: dict,
    data_reader: Optional[CalibrationDataReader] = None,
    *args,
    **kwargs
) -> ONNXModel:
    """Quantize model layer by layer to save memory.

    Args:
        model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model.
        quant_func (Callable): quantization algorithm function.
        weight_config (dict): quantization config.
        data_reader (CalibrationDataReader, optional): data reader for calibration. Defaults to None.

    Returns:
        ONNXModel: layer-wise quantized model.
    """
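    # Minimal usage sketch (illustrative; `rtn_quantize` stands in for any
    # weight-only quantization function with a compatible signature):
    #     qmodel = layer_wise_quant("model_optimized.onnx", rtn_quantize, weight_config)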
    # TODO: remove the limitation for lwq
    if Version(transformers.__version__) > Version("4.37.2"):
        logger.warning(
            "Models (such as llama-2) exported with transformers {} may fail in layer-wise quantization. "
            "We recommend downgrading transformers to 4.37.2 and trying again.".format(transformers.__version__)
        )

    # check whether model shapes have been inferred
    if not check_model_with_infer_shapes(model):
        logger.error(
            "Before applying layer-wise quantization, please make sure to "
            "run symbolic shape inference on your model as follows:\n"
            "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n"
            "model = onnx.load(your_model_path)\n"
            "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n"
            "onnx.save(out, infer_shape_model_path)\n"
        )
        raise ValueError("Failed to run layer-wise quantization.")

    if not isinstance(model, ONNXModel):
        model = ONNXModel(model, ignore_warning=True, load_external_data=False)

    origin_model = deepcopy(model)

    providers = kwargs.get("providers", ["CPUExecutionProvider"])

    # get and check split nodes
    split_nodes = origin_model.find_split_nodes()
    if len(split_nodes) == 0:
        logger.error(
            "Can't find split nodes for layer-wise quantization. "
            "We recommend applying graph optimization to your model as follows: \n"
            "import onnxruntime as ort \n"
            "sess_options = ort.SessionOptions() \n"
            "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED "
            "# or ORT_ENABLE_BASIC \n"
            "sess_options.optimized_model_filepath = 'optimized_model_path' \n"
            "ort.InferenceSession(infer_shape_model_path, sess_options)"
        )
        raise ValueError("Failed to run layer-wise quantization.")
    logger.info(
        "Will split model into {} parts to do layer-wise quantization".format(len(split_nodes) + 1)
    )
    logger.debug(
        "Will split model with these nodes for layer-wise quantization: {}".format(
            [node.name for node in split_nodes]
        )
    )

    split_idx = 1
    model_to_split = [origin_model]
    quantized_model_merged = None

    require_data_reader = data_reader is not None
    if require_data_reader:
        lwq_data_reader = [data_reader]

    while len(model_to_split) != 0:
        # prepare model, node and data_reader for the current split
        split_model = model_to_split.pop(0)
        split_node = split_nodes.pop(0)
        if require_data_reader:
            current_data_reader = lwq_data_reader.pop(0)

        # if there are no remaining split nodes, this is the last split and both split models will be saved
        save_both_split_models = len(split_nodes) == 0

        # split model at the given split node
        split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
            split_node.name, model.model_path, save_both_split_models
        )
        if not save_both_split_models:
            # append split_model_part_2 for the next split
            model_to_split.append(split_model_part_2)

        logger.info("Quantize split model {}".format(split_idx))
        if require_data_reader:
            # process data_reader for the current split and the next split
            current_data_reader = _filter_data_reader_for_current_split_model(
                split_model_part_1.model, current_data_reader
            )
            next_data_reader = _prepare_data_reader_for_next_split_model(
                split_model_part_1.model_path, current_data_reader, providers
            )
            lwq_data_reader.append(next_data_reader)

            # perform quantization
            split_model_part_1_quantized = quant_func(
                split_model_part_1,
                weight_config=weight_config,
                data_reader=current_data_reader,
                return_modelproto=False,
                **kwargs
            )
        else:
            # perform quantization
            split_model_part_1_quantized = quant_func(
                split_model_part_1, weight_config=weight_config, return_modelproto=False, **kwargs
            )

        # check that the split model is valid
        try:
            ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
        except Exception as e:
            logger.error(
                "Layer-wise quantized model {} can't be inferred correctly. "
                "Please check the raised exception.".format(split_idx)
            )
            raise e

        # merge the quantized split models
        if quantized_model_merged is None:
            quantized_model_merged = split_model_part_1_quantized
            quantized_model_merged.write_external_data_to_new_location(overwrite=True)
        else:
            quantized_model_merged.merge_split_models(split_model_part_1_quantized)

        split_idx += 1
        # if this is the last split, quantize the remaining split model
        if save_both_split_models:
            logger.info("Quantize split model {}".format(split_idx))

            # quantize split model
            if require_data_reader:
                # process data_reader for the current split
                current_data_reader = lwq_data_reader.pop(0)
                current_data_reader = _filter_data_reader_for_current_split_model(
                    split_model_part_2.model, current_data_reader
                )

                # perform quantization
                split_model_part_2_quantized = quant_func(
                    split_model_part_2,
                    weight_config=weight_config,
                    data_reader=current_data_reader,
                    return_modelproto=False,
                    **kwargs
                )
            else:
                # perform quantization
                split_model_part_2_quantized = quant_func(
                    split_model_part_2, weight_config=weight_config, return_modelproto=False, **kwargs
                )

            # check that the split model is valid
            try:
                ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
            except Exception as e:
                logger.error(
                    "Layer-wise quantized model {} can't be inferred correctly. "
                    "Please check the raised exception.".format(split_idx)
                )
                raise e

            # merge the quantized split models
            if quantized_model_merged is None:
                quantized_model_merged = split_model_part_2_quantized
                quantized_model_merged.write_external_data_to_new_location(overwrite=True)
            else:
                quantized_model_merged.merge_split_models(split_model_part_2_quantized)

    # reload external data to prevent external data file path errors
    from onnx.external_data_helper import load_external_data_for_model

    load_external_data_for_model(quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path))

    return quantized_model_merged


class DataReader(CalibrationDataReader):
    """Data reader for layer-wise quantization."""
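    # Example (illustrative): DataReader([{"x": a}, {"x": b}]) yields {"x": a},
    # then {"x": b}, then None from get_next(); rewind() restarts the iteration.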

    def __init__(self, data_list):
        self.data_list = data_list
        self.iter_next = iter(self.data_list)

    def get_next(self):
        return next(self.iter_next, None)

    def rewind(self):
        self.iter_next = iter(self.data_list)


def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_reader: CalibrationDataReader):
    """Filter data reader to remove data that is not among the model's inputs.

    Args:
        model (onnx.ModelProto): onnx model.
        data_reader (CalibrationDataReader): data reader.

    Returns:
        CalibrationDataReader: filtered data reader.
    """
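    # For example (names are illustrative): with model inputs ["hidden_states"],
    # a reader yielding {"input_ids": ids, "hidden_states": h} is narrowed to
    # {"hidden_states": h}.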
    filter_inputs = []
    input_names = [inp.name for inp in model.graph.input]
    while True:
        inputs = data_reader.get_next()
        if not inputs:
            break
        filter_input = {
            input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names
        }
        filter_inputs.append(filter_input)
    return DataReader(filter_inputs)


def _prepare_data_reader_for_next_split_model(
    model_path: str,
    data_reader: CalibrationDataReader,
    providers: Optional[List[str]] = None,
):
    """Prepare data reader for the next split model.

    Get the outputs of the current split model and save them for the next split model.

    Args:
        model_path (str): path to onnx model.
        data_reader (CalibrationDataReader): data reader.
        providers (List[str], optional): providers to use. Defaults to ["CPUExecutionProvider"].

    Returns:
        CalibrationDataReader: data reader for the next split model.
    """
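    # For example (names are illustrative): if the current split consumes
    # {"input_ids": ids} and produces {"hidden_states": h}, the next reader
    # yields {"input_ids": ids, "hidden_states": h}; the next split model then
    # keeps only the inputs it declares (see the filter helper above).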
    if providers is None:
        providers = ["CPUExecutionProvider"]
    data_reader = deepcopy(data_reader)

    data_reader_for_next_split_model = []
    session = ort.InferenceSession(model_path, providers=providers)
    output_names = [output.name for output in session.get_outputs()]
    while True:
        inputs = data_reader.get_next()
        if not inputs:
            break
        out = session.run(None, inputs)
        inputs.update({name: value for name, value in zip(output_names, out)})
        data_reader_for_next_split_model.append(inputs)
    return DataReader(data_reader_for_next_split_model)
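

# ---------------------------------------------------------------------------
# End-to-end sketch (illustrative only; paths and `my_quant_func` are
# hypothetical placeholders). It follows the preparation recipes printed in
# the error messages above:
#
#     import onnx
#     import onnxruntime as ort
#     import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer
#
#     # 1. run symbolic shape inference so the model can be split
#     model = onnx.load("model.onnx")
#     out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)
#     onnx.save(out, "model_infer_shape.onnx")
#
#     # 2. apply graph optimization so split nodes can be found
#     sess_options = ort.SessionOptions()
#     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
#     sess_options.optimized_model_filepath = "model_optimized.onnx"
#     ort.InferenceSession("model_infer_shape.onnx", sess_options)
#
#     # 3. quantize layer by layer with a weight-only quantization function
#     qmodel = layer_wise_quant("model_optimized.onnx", my_quant_func, weight_config={})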