diff --git a/neural_compressor/onnxrt/algorithms/__init__.py b/neural_compressor/onnxrt/algorithms/__init__.py index d40c1e41d0c..c1d38b1844c 100644 --- a/neural_compressor/onnxrt/algorithms/__init__.py +++ b/neural_compressor/onnxrt/algorithms/__init__.py @@ -17,5 +17,6 @@ from neural_compressor.onnxrt.algorithms.weight_only.rtn import apply_rtn_on_model from neural_compressor.onnxrt.algorithms.weight_only.gptq import apply_gptq_on_model from neural_compressor.onnxrt.algorithms.weight_only.awq import apply_awq_on_model +from neural_compressor.onnxrt.algorithms.layer_wise import layer_wise_quant -__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model"] +__all__ = ["Smoother", "apply_rtn_on_model", "apply_gptq_on_model", "apply_awq_on_model", "layer_wise_quant"] diff --git a/neural_compressor/onnxrt/algorithms/layer_wise/__init__.py b/neural_compressor/onnxrt/algorithms/layer_wise/__init__.py new file mode 100644 index 00000000000..86c5371fbb3 --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/layer_wise/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neural_compressor.onnxrt.algorithms.layer_wise.core import layer_wise_quant + +__all__ = ["layer_wise_quant"] diff --git a/neural_compressor/onnxrt/algorithms/layer_wise/core.py b/neural_compressor/onnxrt/algorithms/layer_wise/core.py new file mode 100644 index 00000000000..f6f88b63b78 --- /dev/null +++ b/neural_compressor/onnxrt/algorithms/layer_wise/core.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 MIT HAN Lab +# This source code is licensed under the MIT license +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from copy import deepcopy +from pathlib import Path +from typing import Callable, List, Union + +import onnx +import onnxruntime as ort +import transformers +from packaging.version import Version + +from neural_compressor.common import Logger +from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader +from neural_compressor.onnxrt.utils.onnx_model import ONNXModel +from neural_compressor.onnxrt.utils.utility import check_model_with_infer_shapes + +logger = Logger().get_logger() + +__all__ = [ + "layer_wise_quant", +] + + +def layer_wise_quant( + model: Union[onnx.ModelProto, ONNXModel, Path, str], + quant_func: Callable, + weight_config: dict, + data_reader: CalibrationDataReader = None, + *args, + **kwargs +) -> ONNXModel: + """Quantize model layer by layer to save memory. + + Args: + model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model. + quant_func (Callable): quantization algorithm function. + weight_config (dict): quantization config. + data_reader (CalibrationDataReader, optional): data_reader for calibration. Defaults to None. + + Returns: + ONNXModel: layer-wise quantized model. + """ + # TODO: remove the limitation for lwq + if Version(transformers.__version__) > Version("4.37.2"): + logger.warning( + "Model (such as llama-2) exported with transformers {} may fail in layer-wise quant. " + "We recommend downgrading transformers to 4.37.2 and trying again.".format(transformers.__version__) + ) + + # check whether model shape is inferred + if not check_model_with_infer_shapes(model): + logger.error( + "Before applying layer-wise quantization, please make sure to " + "run symbolic shape inference on your model as follows:\n" + "import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n" + "model = onnx.load(your_model_path)\n" + "out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n" + "onnx.save(out, infer_shape_model_path)\n" + ) + raise ValueError("Fail to run layer-wise quantization.") + + if not isinstance(model, ONNXModel): + model = ONNXModel(model, ignore_warning=True, load_external_data=False) + + origin_model = deepcopy(model) + + providers = kwargs.get("providers", ["CPUExecutionProvider"]) + + # get and check split nodes + split_nodes = origin_model.find_split_nodes() + if len(split_nodes) == 0: + logger.error( + "Can't find split nodes for layer-wise quantization. 
" + "We recommend applying graph optimization for your model like follows: \n" + "import onnxruntime as ort \n" + "sess_options = ort.SessionOptions() \n" + "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED " + "# or ORT_ENABLE_BASIC \n" + "sess_options.optimized_model_filepath = 'optimized_model_path' \n" + "ort.InferenceSession(infer_shape_model_path, sess_options)" + ) + raise ValueError("Fail to run layer-wise quantization.") + logger.info( + "Will split model into {} parts to do layer-wise quantization".format( + len([node.name for node in split_nodes]) + 1 + ) + ) + logger.debug( + "Will split model with these nodes for layer-wise quantization: {}".format([node.name for node in split_nodes]) + ) + + split_idx = 1 + model_to_split = [origin_model] + quantized_model_merged = None + + require_data_reader = data_reader is not None + if require_data_reader: + lwq_data_reader = [data_reader] + + while len(model_to_split) != 0: + # prepare model, node and data_reader for current split + split_model = model_to_split.pop(0) + split_node = split_nodes.pop(0) + if require_data_reader: + current_data_reader = lwq_data_reader.pop(0) + + # if no remaining split nodes, it means this is the last split, and the two split models will be saved. + save_both_split_models = True if len(split_nodes) == 0 else False + + # split model with given split node + split_model_part_1, split_model_part_2 = split_model.split_model_with_node( + split_node.name, model.model_path, save_both_split_models + ) + if not save_both_split_models: + # append split_model_part_2 to do next split + model_to_split.append(split_model_part_2) + + logger.info("Quantize split model {}".format(split_idx)) + if require_data_reader: + # process data_reader for current split and next split + current_data_reader = _filter_data_reader_for_current_split_model( + split_model_part_1.model, current_data_reader + ) + next_data_reader = _prepare_data_reader_for_next_split_model( + split_model_part_1.model_path, current_data_reader, providers + ) + lwq_data_reader.append(next_data_reader) + + # perform quantization + split_model_part_1_quantized = quant_func( + split_model_part_1, + weight_config=weight_config, + data_reader=current_data_reader, + return_modelproto=False, + **kwargs + ) + else: + # perform quantization + split_model_part_1_quantized = quant_func( + split_model_part_1, weight_config=weight_config, return_modelproto=False, **kwargs + ) + + # check split model is valid + try: + ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers) + except Exception as e: + logger.error( + "Layer-wise quantized model {} can't be inferred correctly. 
" + "Please check the raise exception".format(split_idx) + ) + raise e + + # merge split quantized model + if quantized_model_merged is None: + quantized_model_merged = split_model_part_1_quantized + quantized_model_merged.write_external_data_to_new_location(overwrite=True) + else: + quantized_model_merged.merge_split_models(split_model_part_1_quantized) + + split_idx += 1 + # if this is the last split, quantize the last split model + if save_both_split_models: + logger.info("Quantize split model {}".format(split_idx)) + + # quantize split model + if require_data_reader: + # process data_reader for current split + current_data_reader = lwq_data_reader.pop(0) + current_data_reader = _filter_data_reader_for_current_split_model( + split_model_part_2.model, current_data_reader + ) + + # perform quantization + split_model_part_2_quantized = quant_func( + split_model_part_2, + weight_config=weight_config, + data_reader=current_data_reader, + return_modelproto=False, + **kwargs + ) + else: + # perform quantization + split_model_part_2_quantized = quant_func( + split_model_part_2, weight_config=weight_config, return_modelproto=False, **kwargs + ) + + # check split model is valid + try: + ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers) + except Exception as e: + logger.error( + "Layer-wise quantized model {} can't be inferred correctly. " + "Please check the raise exception".format(split_idx) + ) + raise e + + # merge split quantized model + if quantized_model_merged is None: + quantized_model_merged = split_model_part_2_quantized + quantized_model_merged.write_external_data_to_new_location(overwrite=True) + else: + quantized_model_merged.merge_split_models(split_model_part_2_quantized) + + # reload external data to prevent external data file path errors + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path)) + + return quantized_model_merged + + +class DataReader(CalibrationDataReader): + """Data reader for layer-wise quantization.""" + + def __init__(self, data_list): + self.data_list = data_list + self.iter_next = iter(self.data_list) + + def get_next(self): + return next(self.iter_next, None) + + def rewind(self): + self.iter_next = iter(self.data_list) + + +def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_reader: CalibrationDataReader): + """Filter data reader to remove data that is not in model input. + + Args: + model (onnx.ModelProto): onnx model. + data_reader (CalibrationDataReader): data reader. + + Returns: + CalibrationDataReader: filtered data reader. + """ + filter_inputs = [] + input_names = [input.name for input in model.graph.input] + while True: + inputs = data_reader.get_next() + if not inputs: + break + filter_input = { + input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names + } + filter_inputs.append(filter_input) + return DataReader(filter_inputs) + + +def _prepare_data_reader_for_next_split_model( + model_path: str, + data_reader: CalibrationDataReader, + providers: List[str] = ["CPUExecutionProvider"], +): + """Prepare data reader for next split model. + + Get data output of current split model and save for next split model. + + Args: + model (str): path to onnx model. + data_reader (CalibrationDataReader): data reader + providers (List[str], optional): providers to use. Defaults to ["CPUExecutionProvider"]. 
+ + Returns: + CalibrationDataReader: data reader for next split model. + """ + data_reader = deepcopy(data_reader) + + data_reader_for_next_split_model = [] + session = ort.InferenceSession(model_path, providers=providers) + output_names = [output.name for output in session.get_outputs()] + while True: + inputs = data_reader.get_next() + if not inputs: + break + out = session.run(None, inputs) + inputs.update({name: value for name, value in zip(output_names, out)}) + data_reader_for_next_split_model.append(inputs) + return DataReader(data_reader_for_next_split_model) diff --git a/neural_compressor/onnxrt/algorithms/weight_only/awq.py b/neural_compressor/onnxrt/algorithms/weight_only/awq.py index 44dd8839ee1..647d0a9d25e 100644 --- a/neural_compressor/onnxrt/algorithms/weight_only/awq.py +++ b/neural_compressor/onnxrt/algorithms/weight_only/awq.py @@ -275,7 +275,7 @@ def _apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, def awq_quantize( model: Union[onnx.ModelProto, ONNXModel, Path, str], - dataloader: CalibrationDataReader, + data_reader: CalibrationDataReader, weight_config: dict = {}, num_bits: int = 4, group_size: int = 32, @@ -289,7 +289,7 @@ def awq_quantize( Args: model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model. - dataloader (CalibrationDataReader): dataloader for calibration. + data_reader (CalibrationDataReader): data_reader for calibration. weight_config (dict, optional): quantization config For example, weight_config = { @@ -323,8 +323,8 @@ def awq_quantize( full_ratio = {} if enable_mse_search: - inputs, so = prepare_inputs(model, dataloader, providers) - del dataloader + inputs, so = prepare_inputs(model, data_reader, providers) + del data_reader org_output = copy.deepcopy(model.model.graph.output) model.remove_tensors_from_outputs([i.name for i in org_output]) @@ -420,7 +420,7 @@ def apply_awq_on_model( Args: model (Union[onnx.ModelProto, ONNXModel, Path, str]): nnx model. quant_config (dict): quantization config. - calibration_data_reader (CalibrationDataReader): dataloader for calibration. + calibration_data_reader (CalibrationDataReader): data_reader for calibration. Returns: onnx.ModelProto: quantized onnx model. @@ -434,4 +434,4 @@ def apply_awq_on_model( if isinstance(op_config, AWQConfig): quant_config[op_name_type] = op_config.to_dict() - return awq_quantize(model, dataloader=calibration_data_reader, weight_config=quant_config, **kwargs) + return awq_quantize(model, data_reader=calibration_data_reader, weight_config=quant_config, **kwargs) diff --git a/neural_compressor/onnxrt/algorithms/weight_only/gptq.py b/neural_compressor/onnxrt/algorithms/weight_only/gptq.py index 8ddb0f15023..5a8985f1b0f 100644 --- a/neural_compressor/onnxrt/algorithms/weight_only/gptq.py +++ b/neural_compressor/onnxrt/algorithms/weight_only/gptq.py @@ -193,7 +193,7 @@ def find_params(weight): def gptq_quantize( model: Union[onnx.ModelProto, ONNXModel, Path, str], - dataloader: CalibrationDataReader, + data_reader: CalibrationDataReader, weight_config: dict = {}, num_bits: int = 4, group_size: int = 32, @@ -205,12 +205,13 @@ def gptq_quantize( perchannel: bool = True, accuracy_level: int = 0, providers: List[str] = ["CPUExecutionProvider"], -) -> onnx.ModelProto: + return_modelproto: bool = True, +): """Quant the model with GPTQ method. Args: model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model. - dataloader (CalibrationDataReader): dataloader for calibration. + data_reader (CalibrationDataReader): data_reader for calibration. 
weight_config (dict, optional): quantization config For example, weight_config = { @@ -236,6 +237,8 @@ def gptq_quantize( 1(fp32 compute type of jblas kernel), 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), 4 (int8 compute type of jblas kernel). Defaults to 0. providers (list, optional): providers to use. Defaults to ["CPUExecutionProvider"]. + return_modelproto (bool, optional): whether to return onnx.ModelProto. Set to False for layer-wise quant. + Defaults to True. Returns: onnx.ModelProto: quantized onnx model @@ -244,8 +247,8 @@ def gptq_quantize( model = ONNXModel(model) base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" - inputs, so = prepare_inputs(model, dataloader, providers) - del dataloader + inputs, so = prepare_inputs(model, data_reader, providers) + del data_reader org_output = copy.deepcopy(model.model.graph.output) model.remove_tensors_from_outputs([i.name for i in org_output]) output_names = [] @@ -395,7 +398,10 @@ def gptq_quantize( load_external_data_for_model(model.model, os.path.split(model.model_path)[0]) - return model.model + if return_modelproto: + return model.model + else: + return model def apply_gptq_on_model( @@ -408,18 +414,38 @@ def apply_gptq_on_model( Args: model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model. quant_config (dict): quantization config. - calibration_data_reader (CalibrationDataReader): dataloader for calibration. + calibration_data_reader (CalibrationDataReader): data_reader for calibration. Returns: onnx.ModelProto: quantized onnx model. """ - # set model params - kwargs = {} - kwargs = {key: quant_config.pop(key) for key in GPTQConfig.model_params_list if key in quant_config} + # check whether to do layer_wise quant + layer_wise = quant_config.pop("layer_wise_quant", False) + + # set other model params + quant_kwargs = {} + quant_kwargs = {key: quant_config.pop(key) for key in GPTQConfig.model_params_list if key in quant_config} # change op config to dict type for op_name_type, op_config in quant_config.items(): if isinstance(op_config, GPTQConfig): quant_config[op_name_type] = op_config.to_dict() - return gptq_quantize(model, dataloader=calibration_data_reader, weight_config=quant_config, **kwargs) + if layer_wise: + from neural_compressor.onnxrt.algorithms import layer_wise_quant + + quantized_model = layer_wise_quant( + model, + quant_func=gptq_quantize, + weight_config=quant_config, + data_reader=calibration_data_reader, + **quant_kwargs + ) + else: + quantized_model = gptq_quantize( + model, data_reader=calibration_data_reader, weight_config=quant_config, **quant_kwargs + ) + + if isinstance(quantized_model, ONNXModel): + quantized_model = quantized_model.model + return quantized_model diff --git a/neural_compressor/onnxrt/algorithms/weight_only/rtn.py b/neural_compressor/onnxrt/algorithms/weight_only/rtn.py index 66da957a6bc..c4ee941bf17 100644 --- a/neural_compressor/onnxrt/algorithms/weight_only/rtn.py +++ b/neural_compressor/onnxrt/algorithms/weight_only/rtn.py @@ -55,7 +55,8 @@ def rtn_quantize( ratios: dict = {}, accuracy_level: int = 0, providers: List[str] = ["CPUExecutionProvider"], -) -> onnx.ModelProto: + return_modelproto: bool = True, +): """Quantize the model with round to nearst method. Args: @@ -81,7 +82,8 @@ def rtn_quantize( 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), 4 (int8 compute type of jblas kernel). Defaults to 0. providers (list, optional): providers to use. Defaults to ["CPUExecutionProvider"]. 
- + return_modelproto (bool, optional): whether to return onnx.ModelProto. Set to False for layer-wise quant. + Defaults to True. Returns: onnx.ModelProto: quantized onnx model. """ @@ -180,25 +182,41 @@ def rtn_quantize( load_external_data_for_model(model.model, os.path.split(model.model_path)[0]) - return model.model + if return_modelproto: + return model.model + else: + return model -def apply_rtn_on_model(model: onnx.ModelProto, quant_config: dict) -> onnx.ModelProto: +def apply_rtn_on_model(model: Union[onnx.ModelProto, ONNXModel, Path, str], quant_config: dict) -> onnx.ModelProto: """Apply RTN on onnx model. Args: - model (onnx.ModelProto): onnx model. + model (Union[onnx.ModelProto, ONNXModel, Path, str]): onnx model. quant_config (dict): quantization config. Returns: onnx.ModelProto: quantized onnx model. """ - if "providers" in quant_config: - providers = quant_config.pop("providers") + # check whether to do layer_wise quant + layer_wise = quant_config.pop("layer_wise_quant", False) + + # set other model params + quant_kwargs = {} + quant_kwargs = {key: quant_config.pop(key) for key in RTNConfig.model_params_list if key in quant_config} # change op config to dict type for op_name_type, op_config in quant_config.items(): if isinstance(op_config, RTNConfig): quant_config[op_name_type] = op_config.to_dict() - return rtn_quantize(model, weight_config=quant_config, providers=providers) + if layer_wise: + from neural_compressor.onnxrt.algorithms import layer_wise_quant + + quantized_model = layer_wise_quant(model, quant_func=rtn_quantize, weight_config=quant_config, **quant_kwargs) + else: + quantized_model = rtn_quantize(model, weight_config=quant_config, **quant_kwargs) + + if isinstance(quantized_model, ONNXModel): + quantized_model = quantized_model.model + return quantized_model diff --git a/neural_compressor/onnxrt/algorithms/weight_only/utility.py b/neural_compressor/onnxrt/algorithms/weight_only/utility.py index d5a2d80a719..f69f8d57fab 100644 --- a/neural_compressor/onnxrt/algorithms/weight_only/utility.py +++ b/neural_compressor/onnxrt/algorithms/weight_only/utility.py @@ -221,14 +221,6 @@ def prepare_inputs(model, data_reader, providers): convert_attribute=False, ) - session = ( - ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) - if not model.is_large_model - else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) - ) - inputs_names = [i.name for i in session.get_inputs()] - del session - inputs_list = [] while True: inputs = data_reader.get_next() diff --git a/neural_compressor/onnxrt/quantization/config.py b/neural_compressor/onnxrt/quantization/config.py index 3e74d83720b..88a0a56171f 100644 --- a/neural_compressor/onnxrt/quantization/config.py +++ b/neural_compressor/onnxrt/quantization/config.py @@ -19,7 +19,7 @@ from collections import OrderedDict from enum import Enum from pathlib import Path -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Callable, List, NamedTuple, Union import numpy as np import onnx @@ -71,7 +71,10 @@ class RTNConfig(BaseConfig): "act_dtype", "accuracy_level", ] - model_params_list: List[str] = ["providers"] + model_params_list: List[str] = [ + "providers", + "layer_wise_quant", + ] name: str = RTN def __init__( @@ -83,6 +86,7 @@ def __init__( act_dtype: str = "fp32", accuracy_level: int = 0, providers: List[str] = ["CPUExecutionProvider"], + layer_wise_quant: bool = False, white_list: List[OP_NAME_OR_MODULE_TYPE] = DEFAULT_WHITE_LIST, ): 
"""Init RTN weight-only quantization config. @@ -97,6 +101,10 @@ def __init__( 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), 4 (int8 compute type of jblas kernel). Defaults to 0. providers (list, optional): execution providers to use. Defaults to ["CPUExecutionProvider"]. + layer_wise_quant (bool, optional): whether to quantize model layer by layer to save memory footprint. + Check below link for details + https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_layer_wise.md, + default is False. white_list (list, optional): op in white_list will be applied current config. Defaults to DEFAULT_WHITE_LIST. """ @@ -108,6 +116,7 @@ def __init__( self.act_dtype = act_dtype self.accuracy_level = accuracy_level self.providers = providers + self.layer_wise_quant = layer_wise_quant self._post_init() def get_model_params_dict(self): @@ -154,7 +163,7 @@ def to_config_mapping(self, config_list: List[BaseConfig] = None, model_info: li @staticmethod def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> list: if not isinstance(model, onnx.ModelProto): - model = onnx.load(model) + model = onnx.load(model, load_external_data=False) white_list = ["MatMul"] filter_result = [] for node in model.graph.node: @@ -202,6 +211,7 @@ class GPTQConfig(BaseConfig): "mse", "perchannel", "providers", + "layer_wise_quant", ] name: str = GPTQ @@ -219,6 +229,7 @@ def __init__( mse: bool = False, perchannel: bool = True, providers: List[str] = ["CPUExecutionProvider"], + layer_wise_quant: bool = False, white_list: List[OP_NAME_OR_MODULE_TYPE] = DEFAULT_WHITE_LIST, ): """Init GPTQ weight-only quantization config. @@ -240,6 +251,10 @@ def __init__( mse (bool, optional): whether get scale and zero point with mse error. Defaults to False. perchannel (bool, optional): whether quantize weight per-channel. Defaults to True. providers (list, optional): execution providers to use. Defaults to ["CPUExecutionProvider"]. + layer_wise_quant (bool, optional): whether to quantize model layer by layer to save memory footprint. + Check below link for details + https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_layer_wise.md, + default is False. white_list (list, optional): op in white_list will be applied current config. Defaults to DEFAULT_WHITE_LIST. 
""" @@ -256,6 +271,7 @@ def __init__( self.mse = mse self.perchannel = perchannel self.providers = providers + self.layer_wise_quant = layer_wise_quant self._post_init() def get_model_params_dict(self): @@ -305,7 +321,7 @@ def to_config_mapping(self, config_list: list = None, model_info: list = None) - @staticmethod def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> list: if not isinstance(model, onnx.ModelProto): - model = onnx.load(model) + model = onnx.load(model, load_external_data=False) white_list = ["MatMul"] filter_result = [] for node in model.graph.node: @@ -449,7 +465,7 @@ def to_config_mapping(self, config_list: list = None, model_info: list = None) - @staticmethod def get_model_info(model: Union[onnx.ModelProto, Path, str]) -> list: if not isinstance(model, onnx.ModelProto): - model = onnx.load(model) + model = onnx.load(model, load_external_data=False) white_list = ["MatMul"] filter_result = [] for node in model.graph.node: diff --git a/neural_compressor/onnxrt/utils/onnx_model.py b/neural_compressor/onnxrt/utils/onnx_model.py index c8bfc71f5e5..56d45ba7fce 100644 --- a/neural_compressor/onnxrt/utils/onnx_model.py +++ b/neural_compressor/onnxrt/utils/onnx_model.py @@ -21,7 +21,6 @@ from onnxruntime.quantization.onnx_model import ONNXModel as ORTONNXModel from neural_compressor.common import Logger -from neural_compressor.onnxrt.utils.utility import MAXIMUM_PROTOBUF, find_by_name logger = Logger().get_logger() @@ -74,6 +73,8 @@ def model_path(self, path): def check_is_large_model(self): """Check model > 2GB.""" + from neural_compressor.onnxrt.utils.utility import MAXIMUM_PROTOBUF + init_size = 0 for init in self.model.graph.initializer: # if initializer has external data location, return True @@ -420,6 +421,8 @@ def get_nodes_chain(self, start, stop, result_chain=[]): from onnx import NodeProto + from neural_compressor.onnxrt.utils.utility import find_by_name + # process start node list start_node = deque() for node in start: @@ -499,7 +502,7 @@ def find_split_node_for_layer_wise_quantization(self): start_node, ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], [None, 0, 0, 0, 0, 0], - output_name_to_node=self.output_name_to_node, + output_name_to_node_dict=self._output_name_to_node, return_indice=[], ), # match bart attention structure @@ -579,7 +582,7 @@ def find_qkv_in_attention(self, find_all=False): start_node, ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], [None, 0, 0, 0, 0, 0], - output_name_to_node=self.output_name_to_node, + output_name_to_node_dict=self._output_name_to_node, return_indice=[], ), # match bart attention structure @@ -601,7 +604,7 @@ def find_qkv_in_attention(self, find_all=False): qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1] other_inputs = [] for input in start_node.input: - if input not in self.output_name_to_node: + if input not in self._output_name_to_node: continue if input == qkv_nodes[0].output[0]: continue @@ -689,7 +692,7 @@ def remove_tensors_from_outputs(self, tensor_names): for output in removed_outputs: self.model.graph.output.remove(output) - def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]): + def match_first_parent(self, node, parent_op_type, output_name_to_node_dict, exclude=[]): """Find parent node based on constraints on op_type. Args: @@ -703,8 +706,8 @@ def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude= index: The input index of matched parent node. None if not found. 
""" for i, input in enumerate(node.input): - if input in output_name_to_node: - parent = output_name_to_node[input] + if input in output_name_to_node_dict: + parent = output_name_to_node_dict[input] if parent.op_type == parent_op_type and parent not in exclude: return parent, i return None, None @@ -714,7 +717,7 @@ def match_parent( node, parent_op_type, input_index=None, - output_name_to_node=None, + output_name_to_node_dict=None, exclude=[], return_indice=None, ): @@ -734,13 +737,13 @@ def match_parent( assert node is not None assert input_index is None or input_index >= 0 - if output_name_to_node is None: + if output_name_to_node_dict is None: if len(self._output_name_to_node) == 0: self._output_name_to_node = self.output_name_to_node() - output_name_to_node = self._output_name_to_node + output_name_to_node_dict = self._output_name_to_node if input_index is None: - parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node_dict, exclude) if return_indice is not None: return_indice.append(index) return parent @@ -748,7 +751,7 @@ def match_parent( if input_index >= len(node.input): return None - parent = self.get_parent(node, input_index, output_name_to_node) + parent = self.get_parent(node, input_index, output_name_to_node_dict) if parent is not None and parent.op_type == parent_op_type and parent not in exclude: return parent @@ -759,7 +762,7 @@ def match_parent_path( node, parent_op_types, parent_input_index, - output_name_to_node=None, + output_name_to_node_dict=None, return_indice=None, ): """Find a sequence of input edges based on constraints on parent op_type and index. @@ -778,10 +781,10 @@ def match_parent_path( """ assert len(parent_input_index) == len(parent_op_types) - if output_name_to_node is None: + if output_name_to_node_dict is None: if len(self._output_name_to_node) == 0: self._output_name_to_node = self.output_name_to_node() - output_name_to_node = self._output_name_to_node + output_name_to_node_dict = self._output_name_to_node current_node = node matched_parents = [] @@ -790,7 +793,7 @@ def match_parent_path( current_node, op_type, parent_input_index[i], - output_name_to_node, + output_name_to_node_dict, exclude=[], return_indice=return_indice, ) @@ -818,15 +821,12 @@ def find_split_nodes(self): split_nodes = self.find_split_node_for_layer_wise_quantization() return split_nodes - def split_model_with_node( - self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True - ): + def split_model_with_node(self, split_node_name, path_of_model_to_split, save_both_split_models=True): """Split model into two parts at a given node. Args: split_node_name (str): name of the node where the model is split at> path_of_model_to_split (str): path of model to be split. - shape_infer (bool): do shape inference. Default is True. save_both_split_models (bool): whether to save the two split models. False means only save the first split model. True means save both the two split models. @@ -865,21 +865,6 @@ def split_model_with_node( ) split_tensor_name = split_node_output[0] - # infer shape of the model to be split - if shape_infer: - try: - from neural_compressor.adaptor.ox_utils.util import infer_shapes - - self.model = infer_shapes(self.model, auto_merge=True, base_dir=os.path.dirname(self._model_path)) - except Exception as e: # pragma: no cover - logger.error( - "Shape infer fails for layer-wise quantization. 
" - "We would recommend checking the graph optimization level of your model " - "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', " - "as this may help avoid this error." - ) - raise e - split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name) split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape) @@ -895,8 +880,8 @@ def split_model_with_node( insert_output_for_model_1 = [] insert_input_for_model_2 = [] - for output in split_model_part_1.output_name_to_node.keys(): - if output in split_model_part_2.input_name_to_nodes.keys(): + for output in split_model_part_1._output_name_to_node.keys(): + if output in split_model_part_2._input_name_to_nodes.keys(): output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) if output_tensor not in split_model_part_1.model.graph.output: @@ -984,11 +969,11 @@ def _remove_unused_input_output(self): if len(self._input_name_to_nodes) == 0: self._input_name_to_nodes = self.input_name_to_nodes() for output in self.model.graph.output: - if output.name not in self.output_name_to_node.keys(): + if output.name not in self._output_name_to_node.keys(): remove_outputs.append(output) for input in self.model.graph.input: - if input.name not in self.input_name_to_nodes.keys(): + if input.name not in self._input_name_to_nodes.keys(): remove_inputs.append(input) for output in remove_outputs: @@ -1002,7 +987,7 @@ def remove_unused_init(self): if len(self._input_name_to_nodes) == 0: self._input_name_to_nodes = self.input_name_to_nodes() for init in self.model.graph.initializer: - if init.name not in self.input_name_to_nodes.keys(): + if init.name not in self._input_name_to_nodes.keys(): remov_inits.append(init) self.remove_initializers(remov_inits) @@ -1062,7 +1047,7 @@ def merge_split_models(self, to_merge_model): if ( input.name not in self.input() and input.name not in self.output() - and input.name not in self.output_name_to_node.keys() + and input.name not in self._output_name_to_node.keys() ): self.model.graph.input.append(input) diff --git a/neural_compressor/onnxrt/utils/utility.py b/neural_compressor/onnxrt/utils/utility.py index a31704fb2f2..21678717229 100644 --- a/neural_compressor/onnxrt/utils/utility.py +++ b/neural_compressor/onnxrt/utils/utility.py @@ -17,6 +17,7 @@ import numpy as np import onnx +import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer from packaging.version import Version from neural_compressor.common import Logger @@ -41,6 +42,7 @@ "is_B_transposed", "get_qrange_for_qType", "quantize_data", + "check_model_with_infer_shapes", ] ONNXRT116_VERSION = Version("1.16.0") @@ -271,3 +273,16 @@ def quantize_data(data, quantize_range, qType, scheme): scale, zero_point = _calculate_scale_zp(rmin, rmax, quantize_range, qType, scheme) quantized_data = _quantize_data_with_scale_zero(data, qType, scheme, scale, zero_point) return rmin, rmax, zero_point, scale, quantized_data + + +def check_model_with_infer_shapes(model): + """Check if the model has been shape inferred.""" + from neural_compressor.onnxrt.utils.onnx_model import ONNXModel + + if isinstance(model, (Path, str)): + model = onnx.load(model, load_external_data=False) + elif isinstance(model, ONNXModel): + model = model.model + if len(model.graph.value_info) > 0: + return True + return False diff --git a/test/3x/onnxrt/quantization/layer_wise/test_layer_wise.py 
b/test/3x/onnxrt/quantization/layer_wise/test_layer_wise.py new file mode 100644 index 00000000000..c8e7584ee7f --- /dev/null +++ b/test/3x/onnxrt/quantization/layer_wise/test_layer_wise.py @@ -0,0 +1,155 @@ +import os +import shutil +import unittest +from copy import deepcopy + +import onnx +import onnxruntime as ort +import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer +import torch +from optimum.exporters.onnx import main_export +from transformers import AutoTokenizer + +from neural_compressor.common import Logger +from neural_compressor.onnxrt.quantization.calibrate import CalibrationDataReader + +logger = Logger().get_logger() + + +def find_onnx_file(folder_path): + # return first .onnx file path in folder_path + for root, dirs, files in os.walk(folder_path): + for file in files: + if file.endswith(".onnx"): + return os.path.join(root, file) + return None + + +class DummyNLPDataloader(CalibrationDataReader): + def __init__(self, model_name): + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.sequence_a = "intel-extension-for-transformers is based in SH" + self.sequence_b = "Where is intel-extension-for-transformers based? NYC or SH" + + self.encoded_list = [] + encoded_input = dict(self.tokenizer(self.sequence_a, self.sequence_b, return_tensors="pt")) + input_shape = encoded_input["input_ids"].shape + encoded_input["position_ids"] = ( + torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) + ) + + # convert torch tensor to numpy + for input_name, input_value in encoded_input.items(): + if isinstance(input_value, torch.Tensor): + encoded_input[input_name] = input_value.numpy() + + self.encoded_list.append(encoded_input) + self.iter_next = iter(self.encoded_list) + + def get_next(self): + return next(self.iter_next, None) + + def rewind(self): + self.iter_next = iter(self.encoded_list) + + +class TestLayerWiseQuant(unittest.TestCase): + @classmethod + def setUpClass(self): + # onnx model exported with transformers>=4.38.0 is different with low version + # which will cause layer-wise quant ut to fail + # limit transformers to 4.37.2 + # TODO: remove transformers version limitation + llama_id = "yujiepan/llama-2-tiny-3layers-random" + main_export(llama_id, output="llama-2-tiny-3layers-random", task="text-generation") + model_path = find_onnx_file("llama-2-tiny-3layers-random") + + model = onnx.load(model_path) + model = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True) + infer_shape_model_path = "llama-2-tiny-3layers-random/model-infer-shape.onnx" + onnx.save(model, infer_shape_model_path) + + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED + sess_options.optimized_model_filepath = "llama-2-tiny-3layers-random/optimized_model.onnx" + ort.InferenceSession(infer_shape_model_path, sess_options) + + self.llama = "llama-2-tiny-3layers-random/optimized_model.onnx" + self.calibration_data_reader = DummyNLPDataloader(llama_id) + + @classmethod + def tearDownClass(self): + shutil.rmtree("llama-2-tiny-3layers-random", ignore_errors=True) + + def setUp(self): + # print the test name + logger.info(f"Running ONNXRT TestLayerWiseQuant test: {self.id()}") + + def _check_model_is_quantized(self, model): + node_optypes = [node.op_type for node in model.graph.node] + return "MatMulNBits" in node_optypes or "MatMulFpQ4" in node_optypes + + def _get_quantized_matmul_weight(self, model, matmul_name): + weight_init_name = None + for 
node in model.graph.node: + if node.name == matmul_name: + weight_init_name = node.input[1] + if weight_init_name is None: + return None + + weight_init = None + for init in model.graph.initializer: + if init.name == weight_init_name: + weight_init = onnx.numpy_helper.to_array(init) + return weight_init + + def _apply_quantize(self, quant_config, data_reader=None): + from neural_compressor.onnxrt.quantization.quantize import _quantize + + fp32_model = self.llama + if data_reader is None: + qmodel = _quantize(fp32_model, quant_config) + else: + qmodel = _quantize(fp32_model, quant_config, data_reader) + self.assertIsNotNone(qmodel) + return qmodel + + def test_rtn_layer_wise(self): + from neural_compressor.onnxrt.quantization import RTNConfig + + rtn_config = RTNConfig(layer_wise_quant=True) + qmodel_lwq = self._apply_quantize(rtn_config) + self.assertTrue(self._check_model_is_quantized(qmodel_lwq)) + + rtn_config = RTNConfig(layer_wise_quant=False) + qmodel = self._apply_quantize(rtn_config) + self.assertTrue(self._check_model_is_quantized(qmodel)) + + lwq_quantized_weight = self._get_quantized_matmul_weight(qmodel_lwq, "/lm_head/MatMul_Q4") + self.assertIsNotNone(lwq_quantized_weight) + quantized_weight = self._get_quantized_matmul_weight(qmodel, "/lm_head/MatMul_Q4") + self.assertIsNotNone(quantized_weight) + self.assertTrue((lwq_quantized_weight == quantized_weight).all()) + + def test_gptq_layer_wise(self): + from neural_compressor.onnxrt.quantization import GPTQConfig + + self.calibration_data_reader.rewind() + gptq_config = GPTQConfig(layer_wise_quant=True) + qmodel_lwq = self._apply_quantize(gptq_config, self.calibration_data_reader) + self.assertTrue(self._check_model_is_quantized(qmodel_lwq)) + + self.calibration_data_reader.rewind() + gptq_config = GPTQConfig(layer_wise_quant=False) + qmodel = self._apply_quantize(gptq_config, self.calibration_data_reader) + self.assertTrue(self._check_model_is_quantized(qmodel)) + + lwq_quantized_weight = self._get_quantized_matmul_weight(qmodel_lwq, "/lm_head/MatMul_Q4") + self.assertIsNotNone(lwq_quantized_weight) + quantized_weight = self._get_quantized_matmul_weight(qmodel, "/lm_head/MatMul_Q4") + self.assertIsNotNone(quantized_weight) + self.assertTrue((lwq_quantized_weight == quantized_weight).all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/onnxrt/requirements.txt b/test/3x/onnxrt/requirements.txt index 4165ba5e0a6..4a178c61854 100644 --- a/test/3x/onnxrt/requirements.txt +++ b/test/3x/onnxrt/requirements.txt @@ -1,2 +1,3 @@ optimum pytest +transformers==4.37.2 # limitation for test_layer_wise
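
Usage note: the sketch below shows how the new layer_wise_quant option added in this patch can be exercised, mirroring the unit test above. The optimized model path, the DummyNLPDataloader helper, and the internal _quantize entry point are taken from the test file and are assumptions for illustration, not a public API guarantee.

    # Minimal sketch, assuming a shape-inferred and graph-optimized ONNX model
    # prepared as in TestLayerWiseQuant.setUpClass above.
    from neural_compressor.onnxrt.quantization import GPTQConfig, RTNConfig
    from neural_compressor.onnxrt.quantization.quantize import _quantize  # internal entry point used by the test

    model_path = "llama-2-tiny-3layers-random/optimized_model.onnx"  # hypothetical path from the test

    # RTN needs no calibration data; layer_wise_quant=True splits the model at the
    # detected split nodes and quantizes each part in turn to bound peak memory.
    qmodel_rtn = _quantize(model_path, RTNConfig(layer_wise_quant=True))

    # GPTQ additionally takes a CalibrationDataReader; layer_wise_quant filters and
    # re-feeds the reader for each split, as implemented in layer_wise/core.py above.
    data_reader = DummyNLPDataloader("yujiepan/llama-2-tiny-3layers-random")  # helper defined in the test above
    qmodel_gptq = _quantize(model_path, GPTQConfig(layer_wise_quant=True), data_reader)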