diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index ec093ff2..53490a01 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -40,8 +40,19 @@ def run_mllm():
     else:
         tune(args)

+def run_lmms():
+    from transformers.utils.versions import require_version
+    require_version("lmms_eval", "lmms_eval needs to be installed, `pip install lmms_eval`")
+    # from auto_round.script.lmms_eval import setup_lmms_args, eval
+    from auto_round.script.mllm import setup_lmms_parser, lmms_eval
+    args = setup_lmms_parser()
+    lmms_eval(args)
+
 def switch():
-    if "--mllm" in sys.argv:
+    if "--lmms" in sys.argv:
+        sys.argv.remove("--lmms")
+        run_lmms()
+    elif "--mllm" in sys.argv:
         sys.argv.remove("--mllm")
         run_mllm()
     else:
diff --git a/auto_round/auto_quantizer.py b/auto_round/auto_quantizer.py
index 7cbcd2ea..46f41fc2 100644
--- a/auto_round/auto_quantizer.py
+++ b/auto_round/auto_quantizer.py
@@ -536,7 +536,7 @@ def remove_str(input_string: str, sub_str) -> str:
     )
     if "gptq" in layer_backend and "exllamav2" in layer_backend:
         try:
-            from exllamav2_kernels import gemm_half_q_half, make_q_matrix
+            from exllamav2_kernels import gemm_half_q_half, make_q_matrix  # pylint: disable=E0611
         except:
             logger.warning_once(
                 "For better inference performance, please install exllamav2 kernel "
diff --git a/auto_round/mllm/README.md b/auto_round/mllm/README.md
index 9060c168..cd342f70 100644
--- a/auto_round/mllm/README.md
+++ b/auto_round/mllm/README.md
@@ -1,4 +1,19 @@
 # AutoRound for MLLMs
+## Basic Usage (Gaudi2/CPU/GPU)
+A user guide detailing the full list of supported arguments is provided by calling ```auto-round-mllm -h``` in the terminal. Alternatively, you can use ```auto_round_mllm``` instead of ```auto-round-mllm```.
+Set the format you want in `format`; exporting to multiple formats is supported.
+
+```bash
+# experimental feature, default hyperparameters may be changed later
+auto-round-mllm \
+    --model Qwen/Qwen2-VL-2B-Instruct \
+    --bits 4 \
+    --batch_size 1 \
+    --gradient_accumulate_steps 4 \
+    --group_size 128 \
+    --format "auto_round" \
+    --output_dir ./tmp_autoround
+```
 ## API Usage (Gaudi2/CPU/GPU)
 ```python
 from auto_round import AutoRoundMLLM
@@ -21,7 +36,22 @@
 output_dir = "./tmp_autoround"
 autoround.save_quantized(output_dir, format='auto_round', inplace=True)
 ```
-## Template
+### Dataset
+For MLLMs, we use liuhaotian/llava_conv_58k as the default calibration dataset. Through the ```--dataset``` argument, users can choose other datasets such as "liuhaotian/llava_instruct_80k" and "liuhaotian/llava_instruct_150k", or pass a file path to use a local file.
+
+### Support Matrix
+So far, AutoRound for MLLMs supports five model families, including Qwen2-VL, Llama-Vision, Phi3-Vision, Llava-v1.5, and CogVLM2.
+
+|Model          |Eval Lib   |Calibration Dataset|Quant Nontext Module|
+|---------------|-----------|-------------------|--------------------|
+|Qwen2-VL       |vlmeval    |pile/llava         |-                   |
+|Llama-Vision   |lmms_eval  |llava              |✔                   |
+|Phi3-Vision    |vlmeval    |pile/llava         |✔                   |
+|Llava-v1.5     |lmms_eval  |pile/llava         |-                   |
+|CogVLM2        |lmms_eval  |pile/llava         |✔                   |
+
+## New Model Support
+### Template
 For autoround MLLMs, using Template to customize different operations for different models. User can add a custom chat template through json file as below.
 ```json
 {
@@ -33,7 +63,9 @@ For autoround MLLMs, using Template to customize different operations for differ
     "format_separator": "\n",
     "default_system": "You are a helpful assistant.",
     "replace_tokens": ["<image>", "<|vision_start|><|image_pad|><|vision_end|>"],
-    "processor": "qwen2_vl" }
+    "extra_encode": "True",
+    "processor": "qwen2_vl"
+}
 ```

 The special token ```{{content}}``` is a placeholder to tell the preprocessor where to fill in the corresponding dialogue content.
@@ -45,5 +77,5 @@ For example, the input conversations:
 Using the above template, the input will be converted to the specified format required by Qwen2-vl as below:
 ```'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>\nWhat are the colors of the bus in the image?<|im_end|>\n<|im_start|>assistant\nThe bus in the image is white and red.<|im_end|>\n<|im_start|>user\nWhat feature can be seen on the back of the bus?<|im_end|>\n<|im_start|>assistant\nThe back of the bus features an advertisement.<|im_end|>\n<|im_start|>user\nIs the bus driving down the street or pulled off to the side?<|im_end|>\n<|im_start|>assistant\nThe bus is driving down the street, which is crowded with people and other vehicles.<|im_end|>\n'```.

-## Processor
+### Processor
 Processor is callback interface for calling different processors, such as texts or images processors, for MLLMs. User can define own processor and use registration function to declare. For more information, please refer to the relevant code in ```auto_round/mllm/processor.py```.
\ No newline at end of file
diff --git a/auto_round/mllm/__init__.py b/auto_round/mllm/__init__.py
index f42a4f48..41858319 100644
--- a/auto_round/mllm/__init__.py
+++ b/auto_round/mllm/__init__.py
@@ -16,4 +16,4 @@
 from .template import Template, get_template, TEMPLATES
 from .autoround_mllm import AutoRoundMLLM
 from ..utils import LazyImport
-from .eval import mllm_eval
\ No newline at end of file
+from .eval import mllm_eval, lmms_eval
\ No newline at end of file
diff --git a/auto_round/mllm/eval.py b/auto_round/mllm/eval.py
index 58f8ba18..887286de 100644
--- a/auto_round/mllm/eval.py
+++ b/auto_round/mllm/eval.py
@@ -26,6 +26,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Copyright (c) 2024 LMMs-Lab
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
 import os
 import time
 import json
@@ -33,6 +53,8 @@
 import pandas as pd
 from ..utils import logger, LazyImport
+import numpy as np
+
 vlmeval = LazyImport("vlmeval")
@@ -284,3 +306,103 @@
             continue
     rt_file.write('%d tasks cost: %.4fs\n' % (len(dataset), time.time() - st))
     rt_file.close()
+
+
+MODEL_TYPE_TO_LMMS_MODEL = {
+    # model_name
+    "Qwen-VL": "qwen_vl",
+    "Qwen2-VL": "qwen2_vl",
+    "cogvlm2": "cogvlm2",
+    "llava_v1.5": "llava",
+    "Llama-3.2": "llama_vision",
+    "Phi-3-vision": "phi3v",
+    "Phi-3.5-vision": "phi3v",
+
+    # model_type
+    "qwen2_vl": "qwen2_vl",
+    "qwen": "qwen_vl",
+    "llava": "llava",
+    "phi3_v": "phi3v",
+    "mllama": "llama_vision",
+}
+
+_lmms_eval = LazyImport("lmms_eval")
+
+def _handle_non_serializable(o):
+    if isinstance(o, np.int64) or isinstance(o, np.int32):
+        return int(o)
+    elif isinstance(o, set):
+        return list(o)
+    else:
+        return str(o)
+
+def lmms_eval(
+        model,
+        tasks,
+        output_dir=None,
+        num_fewshot=None,
+        limit=None,
+        batch_size=1,
+        max_batch_size=None,
+        device='cpu',
+        use_cache=None,
+        apply_chat_template=False
+    ):
+    from auto_round import AutoRoundConfig  # pylint: disable=unused-import
+
+    if isinstance(tasks, str):
+        tasks = tasks.replace(' ', '').split(',')
+
+    model_name = model
+    if model_name[-1] == "/":
+        model_name = model_name[:-1]
+    model_name = model_name.split("/")[-1]
+
+    model_type = None
+    split_name = model_name.split("-")
+    for i in range(len(split_name), 0, -1):
+        tmp = "-".join(split_name[0:i])
+        if tmp in MODEL_TYPE_TO_LMMS_MODEL:
+            model_type = tmp
+            break
+    if model_type is None:
+        from transformers import AutoConfig
+        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+        model_type = config.model_type
+
+    assert model_type in MODEL_TYPE_TO_LMMS_MODEL, f"{model_type} is not supported by lmms_eval."
+
+    if MODEL_TYPE_TO_LMMS_MODEL[model_type] == "phi3v":
+        model_args = f"model_id_name={model}"
+    else:
+        model_args = f"pretrained={model}"
+    if MODEL_TYPE_TO_LMMS_MODEL[model_type] == "llama_vision":
+        model_args += f",device_map={device}"
+    class CliArgs:
+        output_path = output_dir
+
+    results = _lmms_eval.evaluator.simple_evaluate(
+        model=MODEL_TYPE_TO_LMMS_MODEL[model_type],
+        model_args=model_args,
+        tasks=tasks,
+        num_fewshot=num_fewshot,
+        limit=limit,
+        batch_size=batch_size,
+        max_batch_size=max_batch_size,
+        device=device,
+        use_cache=use_cache,
+        apply_chat_template=apply_chat_template,
+        cli_args=CliArgs()
+    )
+
+    # print and save results
+    print(_lmms_eval.utils.make_table(results))
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+        from datetime import datetime
+        now = datetime.now().strftime("%Y%m%d_%H%M%S")
+        output_file = os.path.join(output_dir, f"{model_name}_{now}_result.json")
+        json.dump(results, open(output_file, 'w'), indent=4, default=_handle_non_serializable)
+
+    return results
diff --git a/auto_round/script/mllm.py b/auto_round/script/mllm.py
index 9afcc1d4..bc6233e5 100644
--- a/auto_round/script/mllm.py
+++ b/auto_round/script/mllm.py
@@ -14,6 +14,7 @@
 import os
 import argparse
+import json

 import torch
 import transformers
@@ -282,9 +283,9 @@ def tune(args):
     else:
         cls = AutoModelForCausalLM

-    model = cls.from_pretrained(
-        model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
-        device_map="auto" if use_auto_mapping else None)
+    model = cls.from_pretrained(
+        model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
+        device_map="auto" if use_auto_mapping else None)

     if "cogvlm2" in model_name:
         model.config.model_type = "cogvlm2"
@@ -389,3 +390,69 @@ def eval(args):
         ignore=args.ignore
     )

+def setup_lmms_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", "--model_name", "--model_name_or_path",
+                        help="model name or path")
+    parser.add_argument(
+        "--tasks",
+        default="pope,textvqa_val,scienceqa,mmbench_en",
+        help="To get full list of tasks, use the command lmms-eval --tasks list",
+    )
+    parser.add_argument("--output_dir", default="./tmp_autoround", type=str,
+                        help="the directory to save evaluation results")
+    parser.add_argument(
+        "--num_fewshot",
+        type=int,
+        default=None,
+        help="Number of examples in few-shot context",
+    )
+    parser.add_argument(
+        "--batch_size",
+        "-b",
+        type=str,
+        default=1,
+        metavar="auto|auto:N|N",
+        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
+    )
+    parser.add_argument(
+        "--max_batch_size",
+        type=int,
+        default=None,
+        metavar="N",
+        help="Maximal batch size to try with --batch_size auto.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="Device to use (e.g. cuda, cuda:0, cpu)",
+    )
+    parser.add_argument(
+        "--limit",
+        type=float,
+        default=None,
+        help="Limit the number of examples per task. "
+             "If <1, limit is a percentage of the total number of examples.",
+    )
+    args = parser.parse_args()
+    return args
+
+def lmms_eval(args):
+    from auto_round.mllm import lmms_eval
+
+    results = lmms_eval(
+        model=args.model,
+        tasks=args.tasks,
+        output_dir=args.output_dir,
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
+        batch_size=args.batch_size,
+        max_batch_size=args.max_batch_size,
+        device=args.device,
+        use_cache=None,
+        apply_chat_template=False,
+    )
+    return results
+
+
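For reference (not part of the patch), below is a minimal sketch of driving the new `lmms_eval` helper from Python. The model path, task names, and device are illustrative placeholders (the model is taken from the README example, the tasks from the parser defaults); the keyword arguments mirror the signature added in `auto_round/mllm/eval.py`.

```python
# Minimal usage sketch for the lmms_eval helper introduced above.
# Assumes the lmms_eval package and auto_round are installed; the model path,
# tasks, and device are placeholders, not values mandated by this patch.
from auto_round.mllm import lmms_eval

results = lmms_eval(
    model="Qwen/Qwen2-VL-2B-Instruct",  # basename is matched against MODEL_TYPE_TO_LMMS_MODEL ("Qwen2-VL" -> "qwen2_vl")
    tasks="pope,textvqa_val",           # comma-separated lmms-eval task names
    output_dir="./tmp_autoround",       # results are written as <model>_<timestamp>_result.json
    batch_size=1,
    device="cuda:0",
)
```

The same run can presumably be launched from the command line through the new `--lmms` switch handled in `auto_round/__main__.py`, e.g. `auto-round-mllm --lmms --model Qwen/Qwen2-VL-2B-Instruct --tasks pope`.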