diff --git a/.azure-pipelines/scripts/ut/run_itrex.sh b/.azure-pipelines/scripts/ut/run_itrex.sh index b7b7f34213c..d70da260694 100644 --- a/.azure-pipelines/scripts/ut/run_itrex.sh +++ b/.azure-pipelines/scripts/ut/run_itrex.sh @@ -6,6 +6,7 @@ echo "run itrex ut..." # prepare itrex git clone https://github.com/intel/intel-extension-for-transformers.git /intel-extension-for-transformers +cd /intel-extension-for-transformers && git rev-parse --short HEAD bash /intel-extension-for-transformers/.github/workflows/script/prepare_env.sh bash /intel-extension-for-transformers/.github/workflows/script/install_binary.sh diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 51a3f70c1dc..c94de80f215 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -118,12 +118,16 @@ def from_dict(cls, config_dict, str2operator=None): Returns: The constructed config. """ - config = cls(**config_dict.get(GLOBAL, {})) - operator_config = config_dict.get(LOCAL, {}) - if operator_config: - for op_name, op_config in operator_config.items(): - config.set_local(op_name, cls(**op_config)) - return config + if GLOBAL not in config_dict and LOCAL not in config_dict: + config = cls(**config_dict) + return config + else: + config = cls(**config_dict.get(GLOBAL, {})) + operator_config = config_dict.get(LOCAL, {}) + if operator_config: + for op_name, op_config in operator_config.items(): + config.set_local(op_name, cls(**op_config)) + return config @classmethod def to_diff_dict(cls, instance) -> Dict[str, Any]: @@ -201,11 +205,11 @@ def to_config_mapping( global_config = config.global_config op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() for op_name, op_type in model_info: - config_mapping.setdefault(op_type, OrderedDict())[op_name] = global_config + config_mapping[(op_type, op_name)] = global_config if op_type in op_type_config_dict: - config_mapping[op_type][op_name] = op_name_config_dict[op_type] + config_mapping[(op_type, op_name)] = op_name_config_dict[op_type] if op_name in op_name_config_dict: - config_mapping[op_type][op_name] = op_name_config_dict[op_name] + config_mapping[(op_type, op_name)] = op_name_config_dict[op_name] return config_mapping @staticmethod @@ -234,9 +238,15 @@ def to_dict(self, params_list=[], operator2str=None): return result @classmethod - def from_dict(cls, config_dict, str2operator=None): - # TODO(Yi) - pass + def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[str, BaseConfig]): + assert len(config_dict) >= 1, "The config dict must include at least one configuration." 
+        # advance a single iterator so each step consumes the next algorithm's config entry
+        num_configs = len(config_dict)
+        config_items = iter(config_dict.items())
+        name, value = next(config_items)
+        config = config_registry[name].from_dict(value)
+        for _ in range(num_configs - 1):
+            name, value = next(config_items)
+            config += config_registry[name].from_dict(value)
+        return config
 
     def to_json_string(self, use_diff: bool = False) -> str:
         return json.dumps(self.to_dict(), indent=2) + "\n"
diff --git a/neural_compressor/common/utility.py b/neural_compressor/common/utility.py
index 51b37092033..d4287f09632 100644
--- a/neural_compressor/common/utility.py
+++ b/neural_compressor/common/utility.py
@@ -26,4 +26,5 @@
 BASE_CONFIG = "base_config"
 COMPOSABLE_CONFIG = "composable_config"
 RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
+GPTQ = "gptq"
 DUMMY_CONFIG = "dummy_config"
diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py
index b8606e0b7f8..a0b414e2994 100644
--- a/neural_compressor/torch/__init__.py
+++ b/neural_compressor/torch/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from neural_compressor.torch.utils import register_algo
-from neural_compressor.torch.algorithms import rtn_quantize_entry
+from neural_compressor.torch.algorithms import rtn_quantize_entry, gptq_quantize_entry
 
 from neural_compressor.torch.quantization import (
     quantize,
@@ -21,4 +21,6 @@
     get_default_rtn_config,
     DummyConfig,
     get_default_dummy_config,
+    GPTQConfig,
+    get_default_gptq_config,
 )
diff --git a/neural_compressor/torch/algorithms/__init__.py b/neural_compressor/torch/algorithms/__init__.py
index 94a7739ef89..ebb6e56ae35 100644
--- a/neural_compressor/torch/algorithms/__init__.py
+++ b/neural_compressor/torch/algorithms/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 
-from neural_compressor.torch.algorithms.rtn_quantize import rtn_quantize_entry
+from neural_compressor.torch.algorithms.weight_only_algos import rtn_quantize_entry
+from neural_compressor.torch.algorithms.weight_only_algos import gptq_quantize_entry
diff --git a/neural_compressor/torch/algorithms/gptq.py b/neural_compressor/torch/algorithms/gptq.py
new file mode 100644
index 00000000000..a42908e7385
--- /dev/null
+++ b/neural_compressor/torch/algorithms/gptq.py
@@ -0,0 +1,1043 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
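+#
+# GPTQ (https://arxiv.org/abs/2210.17323) quantizes weights column by column and compensates the
+# resulting error on the remaining, not-yet-quantized columns using second-order (Hessian)
+# statistics collected from calibration inputs.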
+ +# Copied from neural_compressor/adaptor/torch_utils/gptq.py + +import gc +import math +import random +import re +import time +from collections import UserDict, defaultdict +from functools import partial + +import torch +import torch.nn as nn +import transformers +from tqdm import tqdm + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + + +DEBUG = False + + +# ================ device related =================== +def move_input_to_device(input, device=torch.device("cpu")): + if isinstance(input, dict) or isinstance(input, UserDict): + for inp in input.keys(): + input[inp] = input[inp].to(device) if isinstance(input[inp], torch.Tensor) else input[inp] + elif isinstance(input, list) or isinstance(input, tuple): + input_res, prev_size = [], None + for inp in input: + if prev_size: + if isinstance(inp, torch.Tensor): + if inp.size() == prev_size: + input_res.append(inp.to(device)) + else: + if torch.tensor(inp).size == prev_size: + input_res.append(inp) + else: + input_res.append(inp.to(device) if isinstance(inp, torch.Tensor) else inp) + prev_size = torch.tensor(inp).size() + input = input_res + else: + input = input.to(device) # pylint: disable=no-member + return input + + +# ==============model structure related============== +def is_leaf(module): + """Judge whether a module has no child-modules. + + Args: + module: torch.nn.Module + + Returns: + a bool: whether a module has no child-modules. + """ + children_cnt = 0 + for n in module.children(): + children_cnt += 1 + return True if children_cnt == 0 else False + + +def trace_gptq_target_blocks(module, module_types=[torch.nn.ModuleList, torch.nn.Sequential]): + """Search transformer stacked structures, which is critical in LLMs and GPTQ execution. + + Args: + module: torch.nn.Module + module_types: List of torch.nn.Module. + + Returns: + gptq_related_blocks = { + "embeddings": {}, # Dict embedding layers before transformer stack module, + "transformers_pre": {}, # TODO + "transformers_name": string. LLMs' transformer stack module name , + "transformers": torch.nn.ModuleList. 
LLMs' transformer stack module, + "transformers": {}, Dict# TODO + } + """ + if type(module).__name__ == "MixFormerSequentialForCausalLM": # pragma: no cover + gptq_related_blocks = { + "embeddings": {}, + "transformers_pre": {}, # todo + "transformers_name": "", # None + "transformers": [], # None + "transformers_post": {}, # todo + } + for n, m in module.named_modules(): + if type(m) in module_types: + gptq_related_blocks["transformers_name"] = n + gptq_related_blocks["transformers"] = m + break + else: + continue + for n, m in gptq_related_blocks["transformers"][0].named_modules(): + if is_leaf(m): + gptq_related_blocks["embeddings"][n] = m + gptq_related_blocks["transformers"] = gptq_related_blocks["transformers"][1:-1] + else: + gptq_related_blocks = { + "embeddings": {}, + "transformers_pre": {}, # todo + "transformers_name": "", # None + "transformers": [], # None + "transformers_post": {}, # todo + } + for n, m in module.named_modules(): + if type(m) in module_types: + gptq_related_blocks["transformers_name"] = n + gptq_related_blocks["transformers"] = m + return gptq_related_blocks + else: + if is_leaf(m): + gptq_related_blocks["embeddings"][n] = m + return gptq_related_blocks + + +def find_layers(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D], name=""): + """Get all layers with target types.""" + if type(module) in layers: + return {name: module} + else: + # use string type to find name: + if type(module).__name__ in ["Linear"]: + return {name: module} + else: + pass + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + return res + + +def find_layers_name(module, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D], name=""): + """Get all layers with target types.""" + if type(module) in layers: + return [name] + res = [] + for name1, child in module.named_children(): + res += find_layers_name(child, layers=layers, name=name + "." + name1 if name != "" else name1) + return res + + +def log_quantizable_layers_per_transformer( + transformer_blocks, layers=[nn.Conv2d, nn.Conv1d, nn.Linear, transformers.Conv1D] +): + """Print all layers which will be quantized in GPTQ algorithm.""" + logger.info("* * Layer to be quantized * *") + + for block_id in range(len(transformer_blocks["transformers"])): + transformer_block = transformer_blocks["transformers"][block_id] + layers_for_this_tblock = find_layers_name(transformer_block) + layer_names = [ + (transformer_blocks["transformers_name"] + "." + str(block_id) + "." + layer_name) + for layer_name in layers_for_this_tblock + ] + for name in layer_names: + logger.info(name) + + +# ===============quantization related============================ +def quantize(x, scale, zero, maxq): + """Do quantization.""" + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + +class GPTQuantizer(object): + """Main API for GPTQ algorithm. + + Please refer to: + GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers + url: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + model, + weight_config={}, + nsamples=128, + dataloader_len=10, + use_max_length=True, + pad_max_length=2048, + device=None, + layer_wise=False, + *args, + **kwargs, + ): + """ + Args: + model: the fp32 model to quantize + weight_config (dict, optional): contains all info required by GPTQ. 
Defaults to {}. For example, + weight_config={ + 'layer1': + { + 'bits': 4, + 'group_size': 32, + 'sym': False, + 'percdamp': .01, + 'act_order': False + } + ... + } + dataloader: an iterable containing calibration datasets, contains (inputs, targets) + device: cpu or cuda + """ + # model + self.model = model + # self.use_cache = self.model.config.use_cache + self.gptq_related_blocks = trace_gptq_target_blocks(self.model) # get the transformer block list above + self.dtype = next(iter(self.model.parameters())).dtype + log_quantizable_layers_per_transformer(self.gptq_related_blocks) + + # weight config + self.weight_config = weight_config + # default settings, check configs + self.wbits_default = 4 + self.group_size_default = 128 + self.block_size_default = 128 + self.percdamp_default = 0.01 + self.sym_default = False + self.act_order_default = False + self.perchannel_default = True + self.mse_default = False + self.check_layer_config() + + # device + self.device = device + if str(self.model.device).startswith("cuda"): + self.device = self.model.device + self.is_ready = False + + self.layer_wise = layer_wise + + # dataloader + self.use_max_length = use_max_length + self.pad_max_length = pad_max_length + self.dataloader_original = None + self.dataloader = [] + self.dataloader_len = dataloader_len + self.nsamples = nsamples + self.args = args + self.kwargs = kwargs + self.run_fn = self.kwargs.get("run_fn", None) + self.run_args = self.kwargs.get("run_args", None) + self.dataloader_len = dataloader_len + # compare 2.x, use run_fn to calibration + # self.prepare_dataloader() + self._post_init() + + def _post_init(self): + self.cache_key_arguments = { + "i": 0 + } # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.) + # Note that the first elements in cache_positional_arguments is main input: hidden_states + self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm) + self.is_ready = True + + def get_full_layer_name(self, sub_layer_name, block_idx): + transformer_name = self.gptq_related_blocks["transformers_name"] + return ".".join([transformer_name, str(block_idx), sub_layer_name]) + + def check_layer_config(self): + """Copy arguments from weight_config to built-in attributes.""" + if "wbits" in self.weight_config: + tmp_weight_config = {} + for name, module in self.model.named_modules(): + tmp_weight_config[name] = {} + tmp_weight_config[name]["wbits"] = self.weight_config.get("wbits", self.wbits_default) + tmp_weight_config[name]["group_size"] = self.weight_config.get("group_size", self.group_size_default) + tmp_weight_config[name]["block_size"] = self.weight_config.get("block_size", self.group_size_default) + tmp_weight_config[name]["percdamp"] = self.weight_config.get("pecdamp", self.percdamp_default) + tmp_weight_config[name]["sym"] = self.weight_config.get("sym", self.sym_default) + tmp_weight_config[name]["act_order"] = self.weight_config.get("act_order", self.act_order_default) + tmp_weight_config[name]["perchannel"] = self.weight_config.get("perchannel", self.perchannel_default) + tmp_weight_config[name]["mse"] = self.weight_config.get("mse", self.mse_default) + self.weight_config = tmp_weight_config + else: + for layer_name, config in self.weight_config.items(): + self.weight_config[layer_name]["wbits"] = config.get("wbits", self.wbits_default) + self.weight_config[layer_name]["group_size"] = config.get("group_size", self.group_size_default) + self.weight_config[layer_name]["block_size"] = 
config.get("block_size", self.group_size_default) + self.weight_config[layer_name]["percdamp"] = config.get("pecdamp", self.percdamp_default) + self.weight_config[layer_name]["sym"] = config.get("sym", self.sym_default) + self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default) + self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default) + self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default) + + def get_layer_config(self, layer_name): + """Obtain config for one layer, since GPTQ supports layer-wise config.""" + # First try the exact name matching, if cannot find, use re to search. For example, can support ".*" in op_name + config = None + config = self.weight_config.get(layer_name, None) + if config is not None: + return config + else: + for k, v in self.weight_config.items(): + regex = re.compile(k) + if len(regex.findall(layer_name)) is not None: + config = v + return config + else: + pass + return config + + def track_hidden_states(self, data): + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, tuple) or isinstance(data, list): + return data[0] + + @torch.no_grad() + def pre_quantization(self): + """Prepare input calibration data and other attributes which are critical for gptq execution.""" + + # critical: hooker function which collects inputs + def forward(layer, *args, **kwargs): + # inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1 + self.cache_key_arguments["i"] += 1 + for arg in kwargs: + # TODO: investigate include parameters + # each outputs can be different shape, hence also use list to store + if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi": + if self.cache_key_arguments.get(arg, None) is None: + self.cache_key_arguments[arg] = [] + self.cache_key_arguments[arg].append(kwargs[arg]) + continue + # copy positional arguments, positional arguments are sensitive for their order, be cautious! + # Most models in HF has avoid this, but some models still use positional arguments other than + # hidden_states, chatglm2-6b etc. + for idx, item in enumerate(args): + if (idx + 1) > len(self.cache_positional_arguments): + # initialize + self.cache_positional_arguments.append([]) + self.cache_positional_arguments[idx].append(item) + raise ValueError + + # Step1: fetch the embeddings and other layers before the transformer stack. 
+ if not self.layer_wise: + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer = embedding_layer.to(self.device) + + # Step2: modify the first transformer block's forward function to obtain inputs for calibration + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) + forward_cache = self.gptq_related_blocks["transformers"][0].forward + self.gptq_related_blocks["transformers"][0].forward = partial( + forward, self.gptq_related_blocks["transformers"][0] + ) + + # Step3: run forward to obtain calibration datasets + logger.info("Collecting calibration inputs...") + logger.info("Collecting calibration inputs by running the run_fn provided by user.") + if self.run_args: + self.run_fn(self.model, self.run_args) + else: + self.run_fn(self.model) + + # for batch in tqdm(self.dataloader): + # if not self.layer_wise: + # batch = move_input_to_device(batch, self.device) + # try: + # if isinstance(batch, tuple) or isinstance(batch, list): + # self.model(batch[0]) + # elif isinstance(batch, dict): + # self.model(**batch) + # else: + # self.model(batch) + # except ValueError: + # pass + # output inp data shape + logger.info("All calibration data's shape =>") + # check all hidden_states shape + try: + for hidden_states in self.cache_positional_arguments[0]: + logger.info(hidden_states.shape) + except: + pass + logger.info("Done.") + + # Step 4: restore original forward function, relocate layers back to cpu. + self.gptq_related_blocks["transformers"][0].forward = forward_cache + if not self.layer_wise: + self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() + for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): + embedding_layer.to(self.device) + torch.cuda.empty_cache() + # end + logger.info("GPTQ quantization prepared.") + + def gather_single_batch_from_dict(self, data_dict, idx): + # obtain a set of keyword input from cache + single_batch = {} + for k, v in data_dict.items(): + single_batch[k] = data_dict[k][idx] + return single_batch + + def gather_single_batch_from_list(self, data_list, idx): + # obtain a set of keyword input from cache + single_batch = [] + for data_item in data_list: + single_batch.append(data_item[idx]) + return single_batch + + def update_blockwise_hidden_states(self, outs): + if "hidden_states" in self.cache_key_arguments: + self.cache_key_arguments["hidden_states"] = outs[:] + else: + self.cache_positional_arguments[0] = outs[:] + + @torch.no_grad() + def execute_quantization(self, means=None, stds=None, model_path=None): + """Run quantization.""" + # Step1: prepare quantization (calibration datasets) + + logger.info("Begin ====>") + self.pre_quantization() + + # Step2: run gptq quantization in a transformer block-wise manner. + gptq_config = {} + tblock_length = len(self.gptq_related_blocks["transformers"]) + for block_idx in range(tblock_length): + logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") + # if we do not apply layer-wise feature, we still place the entire block on the GPU + transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. 
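+            # find_layers returns {relative_layer_name: module}; layers without a matching entry in
+            # self.weight_config are filtered out below and stay unquantized.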
+ sub_layers = find_layers(transformer_block) + sub_layers_to_quant = {} + for layer_name, layer_obj in sub_layers.items(): + # filter sub_layers with included layer_names in self.weight_config + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + # if self.weight_config.get(full_layer_name, None) == None: + if self.get_layer_config(full_layer_name) is None: + logger.warning(f"{full_layer_name} can be quantized " + "but excluded from quantization configs.") + else: + sub_layers_to_quant[layer_name] = layer_obj + del sub_layers + sub_layers = sub_layers_to_quant + # Step 2.2: Initialize GPTQ quantizers for collected layers. + gptq_for_this_block = {} + # initialize gptq quantizer for every layer in a transformer block + for layer_name in sub_layers: + # weight_config_this_layer = self.weight_config.get( + # self.get_full_layer_name(layer_name, block_idx), None + # ) + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + weight_config_this_layer = self.get_layer_config(full_layer_name) + W = sub_layers[layer_name].weight.data.clone() + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) + # gptq_for_this_block[layer_name].quantizer = Quantizer() + gptq_for_this_block[layer_name].quantizer.configure( + weight_config_this_layer["wbits"], + weight_config_this_layer["perchannel"], + weight_config_this_layer["sym"], + weight_config_this_layer["mse"], + ) + + # Step 2.3: modify forward functions to hook inputs data (used in gptq execution) + def add_batch(_name): + def tmp(_, inp, out): + gptq_for_this_block[_name].add_batch(inp[0].data, out.data) # noqa: F821 + + return tmp + + handles = [] # register handles which add inputs and outputs to gptq object + for layer_name in sub_layers: + handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) + idx = self.cache_key_arguments.pop("i") + for j in range(self.dataloader_len): + cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) + cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) + out = transformer_block(*cache_positional_batch, **cache_keyword_batch) + out = self.track_hidden_states(out) + self.cache_key_arguments["i"] = idx + for h in handles: + h.remove() + # Step 2.4: everything is prepared, so start quantization! 
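+            # For every collected layer, fasterquant() quantizes the weight a block of columns at a
+            # time and uses the Cholesky factor of the damped inverse Hessian to push each column's
+            # quantization error onto the columns that are not quantized yet.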
+ for layer_name in sub_layers: + # weight_config_this_layer = self.weight_config.get( + # self.get_full_layer_name(layer_name, block_idx), None + # ) + weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) + logger.info(f"Quantizing layer {layer_name}") + W = sub_layers[layer_name].weight.data.clone() + scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( + W, + blocksize=weight_config_this_layer["block_size"], + percdamp=weight_config_this_layer["percdamp"], + groupsize=weight_config_this_layer["group_size"], + act_order=weight_config_this_layer["act_order"], + ) + sub_layers[layer_name].weight.data = Q + gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} + if not weight_config_this_layer["sym"]: + gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp + if weight_config_this_layer["act_order"]: # save perm for restoring the weights + gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"] = gptq_for_this_block[ + layer_name + ].perm + gptq_for_this_block[layer_name].free() + + # Step 2.5: replace output data with quantized weights + outs = [] + idx = self.cache_key_arguments.pop("i") + for j in range(self.dataloader_len): + cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) + cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) + out = transformer_block(*cache_positional_batch, **cache_keyword_batch) + out = self.track_hidden_states(out) + outs.append(out) + self.cache_key_arguments["i"] = idx + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + del gptq_for_this_block + torch.cuda.empty_cache() + # iteratively replace the input with output, thus layerwise quantization can continue. 
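+            # i.e. the quantized block's outputs become the cached hidden_states that feed the next block.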
+ self.update_blockwise_hidden_states(outs) + logger.info("------------------------------") + + logger.info("Quantization done") + # self.model.config.use_cache = self.use_cache + + # obtain model (all weight only quantization API function should return) + for k, v in gptq_config.items(): + for m, n in v.items(): + gptq_config[k][m] = n.tolist() + return self.model, gptq_config + + +class GPTQ: + """ + Please refer to: + GPTQ: Accurate Post-training Compression for Generative Pretrained Transformers (https://arxiv.org/abs/2210.17323) + """ + + def __init__(self, layer, W, device="cpu"): + self.layer = layer + self.device = device + # W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] # output channels + self.columns = W.shape[1] # input channels + self.H = torch.zeros((self.columns, self.columns), device=self.device) + self.nsamples = 0 + self.quantizer = Quantizer() + self.perm = None # act_order choice + + def add_batch(self, inp, out): + # if DEBUG: + # self.inp1 = inp + # self.out1 = out + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + # TODO: llm's transformer sequential with nn.conv2d is currently not under test + # if isinstance(self.layer, nn.Conv2d): + # unfold = nn.Unfold( + # self.layer.kernel_size, + # dilation=self.layer.dilation, + # padding=self.layer.padding, + # stride=self.layer.stride + # ) + # inp = unfold(inp) + # inp = inp.permute([1, 0, 2]) + # inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix + + def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): + # W = self.layer.weight.data.clone() + weight_shape, weight_dtype = W.shape, W.data.dtype + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 # such channel makes no contribution to quantization computation + + # rearrange considering the diag's value + if act_order: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + self.perm = perm.clone() + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.device) + H[diag, diag] += damp # add a average value of + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + scale = [] + zero = [] + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): # within a block, channel wise + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % 
groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True) + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + + q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + # if DEBUG: + # self.layer.weight.data[:, :i2] = Q[:, :i2] + # self.layer.weight.data[:, i2:] = W[:, i2:] + # logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") + # logger.info(f"{torch.sum(Losses)}") + + if str(self.device).startswith("cuda"): + torch.cuda.synchronize() + logger.info(f"time {(time.time() - tick)}") + logger.info(f"error {torch.sum(Losses).item()}") + + if act_order: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + # self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + Q = Q.reshape(weight_shape).to(weight_dtype) + if DEBUG: + logger.info(f"{torch.sum((self.layer(self.inp1) - self.out1) ** 2)}") + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, Q + + def free(self): + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() + + +class Quantizer(nn.Module): + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) + + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8, trits=False): + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < 
best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + # def quantize(self, x): + # if self.ready(): + # return quantize(x, self.scale, self.zero, self.maxq) + # return x + + # def enabled(self): + # return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + + +# TODO (Yi) remove it after unifying the algo config parser +from typing import Callable, Dict, Tuple + +from neural_compressor.torch.quantization.config import GPTQConfig + + +def gptq_config_mapping(configs_mapping: Dict[Tuple[str, Callable], GPTQConfig]): + # convert GPTQ_CONFIG to gptq_quantize's weight config + # convert tune_cfg to gptq_quantize's weight config + # for layer_wise quant mode + # TODO (Yi) uncomment it when port layer-wise + # if recipe_cfgs.get("layer_wise_quant", False): + # layer_wise = True + # from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks + + # os.makedirs(LWQ_WORKSPACE, exist_ok=True) + # # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) + # model_path = model.path + # assert model_path, "model_path should not be None." + # model_path = _get_path(model_path) + # lwq_handles = register_weight_hooks( + # model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE + # ) + + weight_config = {} + for (op_type, op_name), op_config in configs_mapping.items(): + if op_config.weight_dtype == "fp32": + continue + else: + weight_config[op_name] = { + "wbits": op_config.weight_bits, + "group_size": op_config.weight_group_size, + "sym": op_config.weight_sym, + "percdamp": op_config.percdamp, + "act_order": op_config.act_order, + "block_size": op_config.block_size, + "mse": op_config.enable_mse_search, + } + nsamples = op_config.nsamples + dataloader_len = op_config.dataloader_len + use_max_length = op_config.use_max_length + pad_max_length = op_config.pad_max_length + device = op_config.device + + if use_max_length and op_config.pad_max_length == 2048: + logger.warning( + "You choose to use unified sequence length for calibration, \ + but you have not set length value. Default sequence length is 2048 and this might cause inference error!" 
+ ) + + return weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len + + +def apply_gptq_quantize(model, configs_mapping, *args, **kwargs): + """Apply gptq.""" + # TODO: unify weight_config keys, add docstring, and support default config + weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len = gptq_config_mapping( + configs_mapping + ) + assert isinstance(model, torch.nn.Module), "only support torch module" + # TODO (Yi) disable layer-wise and model_path first + layer_wise = False + model_path = None + + gptq_quantizer = GPTQuantizer( + model, + weight_config, + nsamples, + dataloader_len, + use_max_length, + pad_max_length, + device, + layer_wise=layer_wise, + *args, + **kwargs, + ) + fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) + logger.info("GPTQ quantization done.") + return fp32_modified_model, gptq_config + + +class DataloaderPreprocessor: + def __init__(self, dataloader_original, use_max_length=False, pad_max_length=2048, nsamples=128) -> None: + self.dataloader_original = dataloader_original + self.use_max_length = use_max_length + self.pad_max_length = pad_max_length + self.nsamples = nsamples + self.dataloader = [] + self.is_ready = False + + def get_prepared_dataloader(self): + if not self.is_ready: + self.prepare_dataloader() + return self.dataloader + + def prepare_dataloader(self): + if self.use_max_length: + # (Recommend) only take sequence whose length exceeds self.pad_max_length, + # which preserves calibration's tokens are all valid + # This is GPTQ official dataloader implementation + self.obtain_first_n_samples_fulllength() + else: + # general selection, no padding, not GPTQ original implementation. + self.obtain_first_n_samples() + self.is_ready = True + + def obtain_first_n_samples(self, seed=0): + """Get first nsample data as the real calibration dataset.""" + self.dataloader.clear() + random.seed(seed) + for batch in self.dataloader_original: + # process data, depends on its data type. 
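+            # Three batch layouts are handled below: list/tuple, dict (expects an "input_ids" entry),
+            # and a plain tensor; sequences longer than pad_max_length are randomly cropped to fit.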
+ if len(self.dataloader) == self.nsamples: + logger.info(f"Successfully collect {self.nsamples} calibration samples.") + break + # list, tuple + if isinstance(batch, list) or isinstance(batch, tuple): + if batch[0].shape[-1] > self.pad_max_length: + i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1) + j = i + self.pad_max_length + batch_final = [] + for item in batch: + if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: + batch_final.append(item[:, i:j]) + else: + batch_final.append(item) + else: + batch_final = batch[:] + # dict + elif isinstance(batch, dict): + try: + length = batch["input_ids"].shape[-1] + except: + logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") + continue + batch_final = {} + if length > self.pad_max_length: + i = random.randint(0, length - self.pad_max_length - 1) + j = i + self.pad_max_length + # may have to slice every sequence related data + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch_final[key] = batch[key][:, i:j] # slice on sequence length dim + else: + batch_final[key] = batch[key] + else: + batch_final = batch + # tensor + else: + if batch.shape[-1] > self.pad_max_length: + i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1) + j = i + self.pad_max_length + batch_final = batch[:, i:j] + else: + batch_final = batch + self.dataloader.append(batch_final) + + if len(self.dataloader) < self.nsamples: + logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.") + + def obtain_first_n_samples_fulllength(self, seed=0): + self.dataloader.clear() + random.seed(seed) + unified_length = self.pad_max_length + for batch in self.dataloader_original: + if len(self.dataloader) == self.nsamples: + logger.info(f"Successfully collect {self.nsamples} calibration samples.") + break + # list & tuple, gpt-j-6b mlperf, etc. 
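+            # Only batches whose sequence length reaches unified_length are kept: longer sequences
+            # are randomly cropped, shorter ones are skipped.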
+ if isinstance(batch, list) or isinstance(batch, tuple): + if batch[0].shape[-1] == unified_length: + batch_final = batch[:] + elif batch[0].shape[-1] > unified_length: + i = random.randint(0, batch[0].shape[-1] - unified_length - 1) + j = i + unified_length + batch_final = [] + for item in batch: + if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: + batch_final.append(item[:, i:j]) + else: + batch_final.append(item) + else: + # not match max length, not include in target dataset + continue + # dict + elif isinstance(batch, dict): + try: + length = batch["input_ids"].shape[-1] + except: + logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") + continue + batch_final = {} + if length == self.pad_max_length: + batch_final = batch + elif length > self.pad_max_length: + i = random.randint(0, length - self.pad_max_length - 1) + j = i + self.pad_max_length + # may have to slice every sequence related data + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch_final[key] = batch[key][:, i:j] # slice on sequence length dim with same position + else: + batch_final[key] = batch[key] + else: + # not match max length, not include in target dataset + continue + # tensor + else: + if batch.shape[-1] == unified_length: + batch_final = batch + elif batch.shape[-1] > unified_length: + i = random.randint(0, batch.shape[-1] - unified_length - 1) + j = i + unified_length + batch_final = batch[:, i:j] + else: + # not match max length, not include in target dataset + continue + self.dataloader.append(batch_final) + if len(self.dataloader) < self.nsamples: # pragma: no cover + logger.warning( + f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \ + but only {len(self.dataloader)} samples are found. Please use smaller 'self.pad_max_length' value." + ) diff --git a/neural_compressor/torch/algorithms/rtn.py b/neural_compressor/torch/algorithms/rtn.py index 1c0071d99e9..a6eb9c779af 100644 --- a/neural_compressor/torch/algorithms/rtn.py +++ b/neural_compressor/torch/algorithms/rtn.py @@ -658,3 +658,29 @@ def rtn_quantize( q_weight = q_weight.T if group_dim == 0 else q_weight m.weight.data.copy_(q_weight) return model + + +from neural_compressor.torch.quantization.config import RTNWeightQuantConfig + + +def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: + # TODO (Yi) remove it + enable_full_range = quant_config.enable_full_range + enable_mse_search = quant_config.enable_mse_search + group_dim = quant_config.group_dim + dtype = quant_config.weight_dtype + num_bits = quant_config.weight_bits + scheme = "sym" if quant_config.weight_sym else "asym" + group_size = quant_config.weight_group_size + return_int = quant_config.return_int + return rtn_quantize( + module, + num_bits, + group_size, + scheme, + return_int=return_int, + data_type=dtype, + enable_full_range=enable_full_range, + enable_mse_search=enable_mse_search, + group_dim=group_dim, + ) diff --git a/neural_compressor/torch/algorithms/rtn_quantize.py b/neural_compressor/torch/algorithms/rtn_quantize.py deleted file mode 100644 index 55e9fd31f4d..00000000000 --- a/neural_compressor/torch/algorithms/rtn_quantize.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Dict - -import torch - -from neural_compressor.common.base_config import BaseConfig -from neural_compressor.common.logger import Logger -from neural_compressor.common.utility import RTN_WEIGHT_ONLY_QUANT -from neural_compressor.torch.algorithms.rtn import rtn_quantize as torch_rtn_quantize -from neural_compressor.torch.quantization.config import RTNWeightQuantConfig -from neural_compressor.torch.utils import fetch_module, register_algo, set_module - -logger = Logger().get_logger() - - -def _apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: - enable_full_range = quant_config.enable_full_range - enable_mse_search = quant_config.enable_mse_search - group_dim = quant_config.group_dim - dtype = quant_config.weight_dtype - num_bits = quant_config.weight_bits - scheme = "sym" if quant_config.weight_sym else "asym" - group_size = quant_config.weight_group_size - return_int = quant_config.return_int - return torch_rtn_quantize( - module, - num_bits, - group_size, - scheme, - return_int=return_int, - data_type=dtype, - enable_full_range=enable_full_range, - enable_mse_search=enable_mse_search, - group_dim=group_dim, - ) - - -def _convert_quant_config_into_quant_config_mapping( - fp32_model: torch.nn.Module, quant_config: BaseConfig -) -> Dict[str, BaseConfig]: - # TODO(Yi) enhance it, currently we only assign the global config to module - # model_info: List[Tuple[str, Callable]] = [] - linear_lst = [] - for name, module in fp32_model.named_modules(): - if isinstance(module, torch.nn.Linear): - linear_lst.append(name) - _quant_config = quant_config if quant_config.global_config is None else quant_config.global_config - quant_config_mapping: Dict[str, BaseConfig] = {name: _quant_config for name in linear_lst} - return quant_config_mapping - - -@register_algo(name=RTN_WEIGHT_ONLY_QUANT) -def rtn_quantize_entry(model: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: - quant_config_mapping: Dict[str, RTNWeightQuantConfig] = _convert_quant_config_into_quant_config_mapping( - model, quant_config - ) - """The main entry to apply rtn quantization.""" - for op_name, quant_config in quant_config_mapping.items(): - original_module = fetch_module(model, op_name) - logger.info(f"Apply RTN on module: {op_name}, {original_module}") - rtn_module = _apply_rtn_on_single_module(original_module, quant_config) - set_module(model, op_name, rtn_module) - return model diff --git a/neural_compressor/torch/algorithms/weight_only_algos.py b/neural_compressor/torch/algorithms/weight_only_algos.py new file mode 100644 index 00000000000..dd07c0d1494 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only_algos.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, Tuple + +import torch + +from neural_compressor.common.logger import Logger +from neural_compressor.common.utility import GPTQ, RTN_WEIGHT_ONLY_QUANT +from neural_compressor.torch.quantization.config import GPTQConfig, RTNWeightQuantConfig +from neural_compressor.torch.utils import fetch_module, register_algo, set_module + +logger = Logger().get_logger() + + +###################### RTN Algo Entry ################################## +@register_algo(name=RTN_WEIGHT_ONLY_QUANT) +def rtn_quantize_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNWeightQuantConfig], *args, **kwargs +) -> torch.nn.Module: + """The main entry to apply rtn quantization.""" + from neural_compressor.torch.algorithms.rtn import apply_rtn_on_single_module + + for (op_type, op_name), quant_config in configs_mapping.items(): + original_module = fetch_module(model, op_name) + logger.info(f"Apply RTN on module: {op_name}, {original_module}") + rtn_module = apply_rtn_on_single_module(original_module, quant_config) + set_module(model, op_name, rtn_module) + return model + + +###################### GPTQ Algo Entry ################################## +@register_algo(name=GPTQ) +def gptq_quantize_entry( + model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs +) -> torch.nn.Module: + logger.info("Quantize model with the GPTQ algorithm.") + from neural_compressor.torch.algorithms.gptq import apply_gptq_quantize + + model, quantization_perm = apply_gptq_quantize(model=model, configs_mapping=configs_mapping, *args, **kwargs) + # Assign the gptq config as an attribute of model + model._gptq_quantization_perm = quantization_perm + return model diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index e159bf99bad..24235271dae 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -18,4 +18,6 @@ get_default_rtn_config, DummyConfig, get_default_dummy_config, + GPTQConfig, + get_default_gptq_config, ) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 16de62fab36..f492baa8517 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -23,7 +23,7 @@ import torch from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs -from neural_compressor.common.utility import DUMMY_CONFIG, RTN_WEIGHT_ONLY_QUANT +from neural_compressor.common.utility import DUMMY_CONFIG, GPTQ, RTN_WEIGHT_ONLY_QUANT FRAMEWORK_NAME = "torch" @@ -47,6 +47,9 @@ class OperatorConfig(NamedTuple): str2operator = {"Linear": torch.nn.Linear, "linear": torch.nn.functional.linear, "Conv2d": torch.nn.Conv2d} +######################## RNT Config ############################### + + @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT) class RTNWeightQuantConfig(BaseConfig): """Config class for round-to-nearest weight-only quantization.""" @@ -139,6 +142,115 @@ def 
get_default_rtn_config() -> RTNWeightQuantConfig:
     return RTNWeightQuantConfig()
 
 
+######################## GPTQ Config ###############################
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ)
+class GPTQConfig(BaseConfig):
+    """Config class for GPTQ.
+
+    GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
+    https://arxiv.org/abs/2210.17323
+    """
+
+    name = GPTQ
+    supported_configs: List[OperatorConfig] = []
+    params_list = [
+        "weight_dtype",
+        "weight_bits",
+        "weight_group_size",
+        "weight_sym",
+        "block_size",
+        "act_dtype",
+        "group_dim",
+        "nsamples",
+        "dataloader_len",
+        "percdamp",
+        "act_order",
+        "use_max_length",
+        "pad_max_length",
+        "enable_mse_search",
+        "device",
+        "layer_wise",
+        "return_int",
+    ]
+
+    def __init__(
+        self,
+        weight_dtype: str = "int",
+        weight_bits: int = 4,
+        weight_group_size: int = 32,
+        weight_sym: bool = True,
+        block_size: int = 128,
+        act_dtype: str = "fp32",
+        group_dim: int = 1,
+        nsamples: int = 128,
+        dataloader_len: int = 10,
+        percdamp: float = 0.01,
+        act_order: bool = False,
+        use_max_length: bool = True,
+        pad_max_length: int = 2048,
+        enable_mse_search: bool = False,
+        device=None,
+        layer_wise: bool = False,
+        return_int: bool = False,
+    ):
+        """Init GPTQ config.
+
+        Args:
+            weight_dtype (str): data type of the quantized weight; "fp32" leaves the op unquantized. Defaults to "int".
+            weight_bits (int): number of bits used for weight quantization. Defaults to 4.
+            weight_group_size (int): number of weight elements sharing one scale/zero-point; -1 disables grouping.
+                Defaults to 32.
+            weight_sym (bool): whether to use symmetric quantization. Defaults to True.
+            block_size (int): number of weight columns GPTQ quantizes at a time. Defaults to 128.
+            act_dtype (str): data type kept for activations (weight-only quantization). Defaults to "fp32".
+            group_dim (int): dimension along which weight grouping is applied. Defaults to 1.
+            nsamples (int): number of calibration samples. Defaults to 128.
+            dataloader_len (int): number of batches in the calibration dataloader. Defaults to 10.
+            percdamp (float): fraction of the average Hessian diagonal added as dampening. Defaults to 0.01.
+            act_order (bool): whether to quantize columns in order of decreasing Hessian diagonal. Defaults to False.
+            use_max_length (bool): whether to only calibrate with sequences of length pad_max_length. Defaults to True.
+            pad_max_length (int): unified calibration sequence length; longer sequences are cropped. Defaults to 2048.
+            enable_mse_search (bool): whether to search quantization parameters by minimizing the quantization error.
+                Defaults to False.
+            device: calibration device, e.g. "cpu" or "cuda". Defaults to None.
+            layer_wise (bool): whether to quantize layer by layer to save memory (not enabled in this port yet).
+                Defaults to False.
+            return_int (bool): whether to return integer weights plus scale/zero-point instead of fake-quantized
+                float weights. Defaults to False.
+        """
+        super().__init__()
+        self.weight_dtype = weight_dtype
+        self.weight_bits = weight_bits
+        self.weight_group_size = weight_group_size
+        self.weight_sym = weight_sym
+        self.act_dtype = act_dtype
+        self.block_size = block_size
+        self.enable_mse_search = enable_mse_search
+        self.group_dim = group_dim
+        self.nsamples = nsamples
+        # TODO(Yi) detect it auto
+        self.dataloader_len = dataloader_len
+        self.percdamp = percdamp
+        self.act_order = act_order
+        self.use_max_length = use_max_length
+        self.pad_max_length = pad_max_length
+        self.layer_wise = layer_wise
+        self.device = device
+        self.return_int = return_int
+
+    def to_dict(self):
+        return super().to_dict(params_list=self.params_list, operator2str=operator2str)
+
+    @classmethod
+    def from_dict(cls, config_dict):
+        return super(GPTQConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
+
+    @classmethod
+    def register_supported_configs(cls) -> List[OperatorConfig]:
+        supported_configs = []
+        # TODO(Yi)
+        linear_gptq_config = GPTQConfig()
+        operators = [torch.nn.Linear, torch.nn.functional.linear]
+        supported_configs.append(
+            OperatorConfig(config=linear_gptq_config, operators=operators, backend=Backend.DEFAULT)
+        )
+        cls.supported_configs = supported_configs
+
+
+# TODO(Yi) run `register_supported_configs` for all registered config.
+GPTQConfig.register_supported_configs()
+
+
+def get_default_gptq_config() -> GPTQConfig:
+    """Generate the default gptq config.
+
+    Returns:
+        the default gptq config.
+ """ + return GPTQConfig() + + +######################## Dummy Config ############################### +# TODO (Yi) remove it after finishing the GPTQ config @register_config(framework_name=FRAMEWORK_NAME, algo_name=DUMMY_CONFIG) class DummyConfig(BaseConfig): """Config class for round-to-nearest weight-only quantization.""" @@ -201,12 +313,3 @@ def get_default_dummy_config() -> DummyConfig: def get_all_registered_configs() -> Dict[str, BaseConfig]: return registered_configs.get(FRAMEWORK_NAME, {}) - - -def parse_config_from_dict(config_dict: Dict) -> BaseConfig: - torch_registered_configs = get_all_registered_configs() - for key, val in config_dict.items(): - if key in torch_registered_configs: - config = torch_registered_configs[key].from_dict(val) - return config - # TODO(Yi) parse multiple configs after support configs add diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index e53023ac363..90744385b95 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -12,46 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable +from typing import Any, Callable, Dict, Tuple import torch -from neural_compressor.common.base_config import BaseConfig +from neural_compressor.common.base_config import BaseConfig, ComposableConfig, registered_configs from neural_compressor.common.logger import Logger -from neural_compressor.common.utility import RTN_WEIGHT_ONLY_QUANT -from neural_compressor.torch.quantization.config import parse_config_from_dict -from neural_compressor.torch.utils import algos_mapping +from neural_compressor.torch.quantization.config import FRAMEWORK_NAME +from neural_compressor.torch.utils import WHITE_MODULE_LIST, algos_mapping, get_model_info logger = Logger().get_logger() +def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_name): + return any(config.name == algo_name for config in configs_mapping.values()) + + def quantize( - model: torch.nn.Module, quant_config: BaseConfig, calib_func: Callable = None, calib_func_arg: Any = None + model: torch.nn.Module, + quant_config: BaseConfig, + run_fn: Callable = None, + run_args: Any = None, ) -> torch.nn.Module: """The main entry to quantize model. Args: model: a float model to be quantized. quant_config: a quantization configuration. - calib_func: a calibration function for calibrating the model. Defaults to None. - calib_func_arg: positional arguments for `calib_func`. Defaults to None. + run_fn: a calibration function for calibrating the model. Defaults to None. + run_args: positional arguments for `run_fn`. Defaults to None. Returns: The quantized model. """ if isinstance(quant_config, dict): - quant_config = parse_config_from_dict(quant_config) - logger.info("Parsed dict to construct the quantization config.") + quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[FRAMEWORK_NAME]) + logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.") else: assert isinstance( quant_config, BaseConfig ), "Please pass a dict or config instance as the quantization configuration." 
logger.info(f"Quantize model with config: \n {quant_config.to_json_string()} \n") # select quantization algo according to config - # TODO (Yi) support combine more than one algo - if quant_config.name == RTN_WEIGHT_ONLY_QUANT: - quant_fn = algos_mapping[quant_config.name] - else: - raise NotImplementedError("Currently, only the rtn algorithm is being ported.") - qmodel = quant_fn(model, quant_config) - return qmodel + + model_info = get_model_info(model=model, white_module_list=WHITE_MODULE_LIST) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.debug(configs_mapping) + for algo_name, algo_func in algos_mapping.items(): + if need_apply(configs_mapping, algo_name): + logger.info(f"Start to apply {algo_name} on the model.") + model = algo_func(model, configs_mapping, run_fn=run_fn, run_args=run_args) + return model diff --git a/neural_compressor/torch/utils.py b/neural_compressor/torch/utils.py index 134bb14797c..289a488ef86 100644 --- a/neural_compressor/torch/utils.py +++ b/neural_compressor/torch/utils.py @@ -24,6 +24,9 @@ import torch +# All constants for torch +WHITE_MODULE_LIST = [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] + def register_algo(name): """Decorator function to register algorithms in the algos_mapping dictionary. @@ -102,4 +105,5 @@ def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> if pair not in filter_result_set: filter_result_set.add(pair) filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") return filter_result diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index e366873eaea..dd210f11bfa 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -119,7 +119,7 @@ def test_config_from_dict(self): }, } } - config = RTNWeightQuantConfig.from_dict(quant_config) + config = RTNWeightQuantConfig.from_dict(quant_config["rtn_weight_only_quant"]) self.assertIsNotNone(config.local_config) def test_config_to_dict(self): @@ -219,8 +219,18 @@ def test_config_mapping(self): logger.info(quant_config) configs_mapping = quant_config.to_config_mapping(model_info=model_info) logger.info(configs_mapping) - self.assertTrue(configs_mapping[torch.nn.Linear]["fc1"].weight_bits == 6) - self.assertTrue(configs_mapping[torch.nn.Linear]["fc2"].weight_bits == 4) + self.assertTrue(configs_mapping[(torch.nn.Linear, "fc1")].weight_bits == 6) + self.assertTrue(configs_mapping[(torch.nn.Linear, "fc2")].weight_bits == 4) + + def test_gptq_config(self): + from neural_compressor.torch.quantization import GPTQConfig + + gptq_config1 = GPTQConfig(weight_bits=8, pad_max_length=512) + quant_config_dict = { + "gptq": {"weight_bits": 8, "pad_max_length": 512}, + } + gptq_config2 = GPTQConfig.from_dict(quant_config_dict["gptq"]) + self.assertEqual(gptq_config1.to_dict(), gptq_config2.to_dict()) if __name__ == "__main__": diff --git a/test/3x/torch/test_gptq_algo.py b/test/3x/torch/test_gptq_algo.py new file mode 100644 index 00000000000..edf35626b81 --- /dev/null +++ b/test/3x/torch/test_gptq_algo.py @@ -0,0 +1,285 @@ +import unittest + +import torch + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + + +def get_gpt_j(): + import transformers + + tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + torchscript=True, + ) + return tiny_gptj + + +class GPTQLLMDataLoader: + def __init__(self, length=512): + self.batch_size = 1 + self.length = length + + def 
__iter__(self): + for i in range(10): + yield torch.ones([1, self.length], dtype=torch.long) + + +class GPTQLLMDataLoaderList(GPTQLLMDataLoader): + def __iter__(self): + for i in range(10): + yield (torch.ones([1, self.length], dtype=torch.long), torch.ones([1, self.length], dtype=torch.long)) + + +class GPTQLLMDataLoaderDict(GPTQLLMDataLoader): + def __iter__(self): + for i in range(10): + yield { + "input_ids": torch.ones([1, self.length], dtype=torch.long), + "attention_mask": torch.ones([1, self.length], dtype=torch.long), + } + + +from tqdm import tqdm + +from neural_compressor.torch.algorithms.gptq import move_input_to_device + + +def run_fn_for_gptq(model, dataloader_for_calibration, *args): + logger.info("Collecting calibration inputs...") + for batch in tqdm(dataloader_for_calibration): + batch = move_input_to_device(batch, device=None) + try: + if isinstance(batch, tuple) or isinstance(batch, list): + model(batch[0]) + elif isinstance(batch, dict): + model(**batch) + else: + model(batch) + except ValueError: + pass + return + + +class TestGPTQ(unittest.TestCase): + @classmethod + def setUpClass(self): + pass + + @classmethod + def tearDownClass(self): + pass + + def setUp(self): + # print the test name + logger.info(f"Running TestGPTQ test: {self.id()}") + + def test_gptq(self): + # Ported from test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py + # TestPytorchWeightOnlyAdaptor.test_GPTQ_fixed_length_quant + from neural_compressor.torch import GPTQConfig, quantize + + dataloader = GPTQLLMDataLoader() + + # case 1: tensor + model_1 = get_gpt_j() + input = torch.ones([1, 512], dtype=torch.long) + out0 = model_1(input) + device = None + from neural_compressor.torch.algorithms.gptq import DataloaderPreprocessor + + dataloaderPreprocessor = DataloaderPreprocessor( + dataloader_original=dataloader, use_max_length=False, pad_max_length=512, nsamples=128 + ) + dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() + + quant_config = GPTQConfig( + weight_group_size=8, dataloader_len=len(dataloader_for_calibration), pad_max_length=512 + ) + quant_config.set_local("lm_head", GPTQConfig(weight_dtype="fp32")) + logger.info(f"Test GPTQ with config {quant_config}") + q_model = quantize( + model=model_1, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration + ) + out1 = q_model(input) + self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) + + def test_gptq_advance(self): + # Ported from test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py + # TestPytorchWeightOnlyAdaptor.test_GPTQ_fixed_length_quant + from neural_compressor.torch import GPTQConfig, quantize + + dataloader = GPTQLLMDataLoader() + model_1 = get_gpt_j() + input = torch.ones([1, 512], dtype=torch.long) + out0 = model_1(input) + + device = None + from neural_compressor.torch.algorithms.gptq import DataloaderPreprocessor + + dataloaderPreprocessor = DataloaderPreprocessor( + dataloader_original=dataloader, use_max_length=False, pad_max_length=512, nsamples=128 + ) + dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() + + quant_config = GPTQConfig( + weight_group_size=8, + dataloader_len=len(dataloader_for_calibration), + act_order=True, + enable_mse_search=True, + pad_max_length=512, + ) + quant_config.set_local("lm_head", GPTQConfig(weight_dtype="fp32")) + logger.info(f"Test GPTQ with config {quant_config}") + q_model = quantize( + model=model_1, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration + ) 
+ out1 = q_model(input) + self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) + + def _apply_gptq(self, input, model, quant_config, run_fn, run_args): + logger.info(f"Test GPTQ with config {quant_config}") + from neural_compressor.torch import quantize + + out0 = model(input) + q_model = quantize(model=model, quant_config=quant_config, run_fn=run_fn, run_args=run_args) + out1 = q_model(input) + self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) + + def test_more_gptq(self): + import random + from itertools import product + + from neural_compressor.torch import GPTQConfig + + # some tests were skipped to accelerate the CI + input = torch.ones([1, 512], dtype=torch.long) + # dataloader + dataloader_collections = [GPTQLLMDataLoader, GPTQLLMDataLoaderList, GPTQLLMDataLoaderDict] + gptq_options = { + "weight_sym": [False, True], + "weight_group_size": [8], + "use_max_length": [False, True], + "pad_max_length": [512], + } + for dataloader_cls in dataloader_collections: + for value in product(*gptq_options.values()): + d = dict(zip(gptq_options.keys(), value)) + quant_config = GPTQConfig(**d) + length = 512 if quant_config.use_max_length else random.randint(1, 1024) + from neural_compressor.torch.algorithms.gptq import DataloaderPreprocessor + + dataloaderPreprocessor = DataloaderPreprocessor( + dataloader_original=dataloader_cls(length), + use_max_length=d["use_max_length"], + pad_max_length=d["pad_max_length"], + nsamples=128, + ) + dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() + quant_config.dataloader_len = len(dataloader_for_calibration) + + self._apply_gptq( + model=get_gpt_j(), + input=input, + quant_config=quant_config, + run_fn=run_fn_for_gptq, + run_args=dataloader_for_calibration, + ) + + def test_gptq_wbits(self): + import copy + import random + + class GPTQLLMDataLoader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(20): + length = random.randint(1, 1024) + yield torch.ones([1, length], dtype=torch.long) + + dataloader = GPTQLLMDataLoader() + model = copy.deepcopy(get_gpt_j()) + weight_config = { + "transformer.h.0.attn.k_proj": { + "wbits": 4, + "group_size": 128, + "sym": True, + "percdamp": 0.01, + "perchannel": False, + }, + "transformer.h.1.attn.k_proj": { + "wbits": 3, + "group_size": -1, + "sym": False, + "percdamp": 0.01, + "act_order": True, + }, + "transformer.h.2.attn.k_proj": { + "wbits": 3, + "group_size": 32, + "sym": False, + "percdamp": 0.01, + "mse": True, + "act_order": False, + }, + "transformer.h.3.attn.k_proj": { + "wbits": 3, + "group_size": 256, + "sym": False, + "percdamp": 0.01, + "mse": True, + "act_order": False, + }, + } + from neural_compressor.torch.algorithms.gptq import DataloaderPreprocessor + + dataloaderPreprocessor = DataloaderPreprocessor( + dataloader_original=dataloader, + use_max_length=True, + pad_max_length=512, + nsamples=128, + ) + preprocessed_dataloader = dataloaderPreprocessor.get_prepared_dataloader() + from neural_compressor.torch.algorithms.gptq import GPTQuantizer + + quantizer = GPTQuantizer( + model=model, + weight_config=weight_config, + dataloader_len=13, + use_max_length=True, + pad_max_length=512, + run_fn=run_fn_for_gptq, + run_args=preprocessed_dataloader, + ) + quantizer.execute_quantization() + self.assertTrue(isinstance(model, torch.nn.Module)) + self.gptj = get_gpt_j() + + model = copy.deepcopy(self.gptj) + weight_config = {"wbits": 4} + dataloaderPreprocessor = DataloaderPreprocessor( + dataloader_original=dataloader, + 
use_max_length=False,
+            pad_max_length=512,
+            nsamples=128,
+        )
+        preprocessed_dataloader = dataloaderPreprocessor.get_prepared_dataloader()
+        quantizer = GPTQuantizer(
+            model=model,
+            weight_config=weight_config,
+            dataloader_len=13,
+            use_max_length=False,
+            pad_max_length=512,
+            run_fn=run_fn_for_gptq,
+            run_args=preprocessed_dataloader,
+        )
+        quantizer.execute_quantization()
+        self.assertTrue(isinstance(model, torch.nn.Module))
+
+
+if __name__ == "__main__":
+    unittest.main()
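
A minimal usage sketch of the GPTQ weight-only flow added in this patch, distilled from test_gptq above; the checkpoint name, calibration dataloader, and run_fn helper below are illustrative assumptions, not part of the library API.

# Sketch of the new GPTQ path: prepare a calibration dataloader, build a
# GPTQConfig, and call the 3.x quantize() entry with run_fn/run_args.
import torch
import transformers

from neural_compressor.torch import GPTQConfig, quantize
from neural_compressor.torch.algorithms.gptq import DataloaderPreprocessor, move_input_to_device

# Tiny GPT-J checkpoint, as used by the tests; any causal LM with Linear layers works.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", torchscript=True
)


class CalibDataloader:
    """Yields fixed-length dummy input_ids, mirroring GPTQLLMDataLoader above."""

    batch_size = 1

    def __iter__(self):
        for _ in range(10):
            yield torch.ones([1, 512], dtype=torch.long)


# Pad/trim the raw batches before handing them to GPTQ.
preprocessor = DataloaderPreprocessor(
    dataloader_original=CalibDataloader(), use_max_length=False, pad_max_length=512, nsamples=128
)
calib_dataloader = preprocessor.get_prepared_dataloader()


def run_fn(model, dataloader, *args):
    # Forward the calibration batches so GPTQ can collect per-layer inputs;
    # GPTQ may raise ValueError to stop the forward early, which is expected.
    for batch in dataloader:
        batch = move_input_to_device(batch, device=None)
        try:
            if isinstance(batch, (tuple, list)):
                model(batch[0])
            else:
                model(batch)
        except ValueError:
            pass


quant_config = GPTQConfig(weight_group_size=8, dataloader_len=len(calib_dataloader), pad_max_length=512)
quant_config.set_local("lm_head", GPTQConfig(weight_dtype="fp32"))  # keep the LM head in fp32
q_model = quantize(model=model, quant_config=quant_config, run_fn=run_fn, run_args=calib_dataloader)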