From 9a1ef937d90c2ec8f7503d4eef3b2a17607de9c4 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 15 Nov 2023 21:34:19 -0800 Subject: [PATCH] Continuous Model Delivery This PR provides a script that automatically quantizes models from HuggingFace using various quantization formats as specified. Example: When provided with the following JSON file: ```json { "destination": "{username}/{model_id}-{quantization}", # Name of HF repo "default_quantization": ["q0f16", "q0f32", "q3f16_1", "q4f16_1", "q4f32_1"], "tasks": [ { "model_id": "Llama-2-7b-hf", "model": "/models/Llama-2-7b-hf", # Can be HF URL or a local path "context_window_size": 4096, "conv_template": "LM", "quantization": [ { "format": "q4f16_awq", "model": "https://huggingface.co/TheBloke/Llama-2-7B-AWQ", # Overriding default `model` "source_format": "awq" } ] } ] } ``` The script will automatically run quantization and upload the results to the following repos: - https://huggingface.co/junrushao/Llama-2-7b-hf-q0f16 - https://huggingface.co/junrushao/Llama-2-7b-hf-q0f32 - https://huggingface.co/junrushao/Llama-2-7b-hf-q3f16_1 - https://huggingface.co/junrushao/Llama-2-7b-hf-q4f16_1 - https://huggingface.co/junrushao/Llama-2-7b-hf-q4f32_1 - https://huggingface.co/junrushao/Llama-2-7b-hf-q4f16_awq --- python/mlc_chat/cli/delivery.py | 257 ++++++++++++++++++ python/mlc_chat/compiler/convert_weight.py | 21 +- .../compiler/loader/huggingface_loader.py | 2 +- python/mlc_chat/support/download.py | 31 ++- 4 files changed, 298 insertions(+), 13 deletions(-) create mode 100644 python/mlc_chat/cli/delivery.py diff --git a/python/mlc_chat/cli/delivery.py b/python/mlc_chat/cli/delivery.py new file mode 100644 index 0000000000..3ec0f3f32f --- /dev/null +++ b/python/mlc_chat/cli/delivery.py @@ -0,0 +1,257 @@ +"""Continuous model delivery for MLC LLM models.""" +import argparse +import dataclasses +import json +import logging +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any, 
Callable, Dict, List, Tuple, Union + +from huggingface_hub import HfApi # pylint: disable=import-error +from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error + +from ..support.argparse import ArgumentParser +from ..support.download import git_clone +from ..support.style import bold, green, red + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelInfo: + """Necessary information for the model delivery""" + + model_id: str + model: Path + conv_template: str + context_window_size: int + quantization: str + source_format: str = "auto" + + +class DeferredScope: + """A context manager that defers execution of functions until exiting the scope.""" + + def __init__(self): + self.deferred_functions = [] + + def add(self, func: Callable[[], None]): + """Add a function to be executed when exiting the scope.""" + self.deferred_functions.append(func) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + for func in reversed(self.deferred_functions): + func() + return False + + def create_temp_dir(self) -> Path: + """Create a temporary directory that will be deleted when exiting the scope.""" + temp_dir = tempfile.mkdtemp() + self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True)) + return Path(temp_dir) + + +def _clone_repo(model: Union[str, Path], deferred: DeferredScope) -> Path: + if isinstance(model, Path): + if not model.exists(): + raise ValueError(f"Invalid model source: {model}") + return model + if model.startswith("https://") or model.startswith("git://"): + result = deferred.create_temp_dir() / "repo" + git_clone(model, result, ignore_lfs=False) + return result + result = Path(model) + if result.exists(): + return result + raise ValueError(f"Invalid model source: {model}") + + +def _run_quantization( + 
model_info: ModelInfo, + repo: str, + api: HfApi, +) -> bool: + logger.info("[HF] Creating repo https://huggingface.co/%s", repo) + try: + api.create_repo(repo_id=repo, private=False) + except HfHubHTTPError as error: + if error.response.status_code != 409: + raise + logger.info("[HF] Repo already exists. Recreating...") + api.delete_repo(repo_id=repo) + api.create_repo(repo_id=repo, private=False) + logger.info("[HF] Repo recreated") + succeeded = True + with tempfile.TemporaryDirectory() as output_dir: + log_path = Path(output_dir) / "logs.txt" + with log_path.open("a", encoding="utf-8") as log_file: + assert isinstance(model_info.model, Path) + logger.info("[MLC] Processing in directory: %s", output_dir) + subprocess.run( + [ + "mlc_chat", + "gen_mlc_chat_config", + "--model", + str(model_info.model), + "--quantization", + model_info.quantization, + "--conv-template", + model_info.conv_template, + "--context-window-size", + str(model_info.context_window_size), + "--output", + output_dir, + ], + check=True, + stdout=log_file, + stderr=subprocess.STDOUT, + ) + subprocess.run( + [ + "mlc_chat", + "convert_weight", + "--model", + str(model_info.model), + "--quantization", + model_info.quantization, + "--source-format", + model_info.source_format, + "--output", + output_dir, + ], + check=False, + stdout=log_file, + stderr=subprocess.STDOUT, + ) + logger.info("[MLC] Complete!") + if not (Path(output_dir) / "ndarray-cache.json").exists(): + logger.error( + "[%s] Model %s. Quantization %s. 
No weights metadata found.", + red("FAILED"), + model_info.model_id, + model_info.quantization, + ) + succeeded = False + logger.info("[HF] Uploading to: https://huggingface.co/%s", repo) + api.upload_folder( + folder_path=output_dir, + repo_id=repo, + commit_message="Initial commit", + ) + return succeeded + + +def _main( # pylint: disable=too-many-locals + username: str, + api: HfApi, + spec: Dict[str, Any], +): + failed_cases: List[Tuple[str, str]] = [] + for task_index, task in enumerate(spec["tasks"], 1): + with DeferredScope() as deferred: + logger.info( + bold("[{task_index}/{total_tasks}] Processing model: ").format( + task_index=task_index, + total_tasks=len(spec["tasks"]), + ) + + green(task["model_id"]) + ) + model = _clone_repo(task["model"], deferred) + for quantization in spec["default_quantization"] + task.get("quantization", []): + model_info = { + "model_id": task["model_id"], + "model": model, + "context_window_size": task["context_window_size"], + "conv_template": task["conv_template"], + } + if isinstance(quantization, str): + model_info["quantization"] = quantization + else: + model_info["quantization"] = quantization.pop("format") + model_info.update(quantization) + repo = spec.get("destination", "{username}/{model_id}-{quantization}").format( + username=username, + model_id=model_info["model_id"], + quantization=model_info["quantization"], + ) + logger.info( + "%s%s. %s%s. 
%s%s", + bold("Model: "), + green(task["model_id"]), + bold("Quantization: "), + green(model_info["quantization"]), + bold("Repo: "), + green(f"https://huggingface.co/{repo}"), + ) + with DeferredScope() as inner_deferred: + model_info["model"] = _clone_repo(model_info["model"], inner_deferred) + result = _run_quantization( + ModelInfo(**model_info), + repo=spec["destination"].format( + username=username, + model_id=model_info["model_id"], + quantization=model_info["quantization"], + ), + api=api, + ) + if not result: + failed_cases.append( + (task["model_id"], model_info["quantization"]), + ) + if failed_cases: + logger.info("Total %s %s:", len(failed_cases), red("failures")) + for model_id, quantization in failed_cases: + logger.info(" Model %s. Quantization %s.", model_id, quantization) + + +def main(): + """Entry point.""" + + def _load_spec(path_spec: str) -> Dict[str, Any]: + path = Path(path_spec) + if not path.exists(): + raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}") + with path.open("r", encoding="utf-8") as i_f: + return json.load(i_f) + + parser = ArgumentParser("MLC LLM continuous model delivery") + parser.add_argument( + "--username", + type=str, + required=True, + help="HuggingFace username", + ) + parser.add_argument( + "--token", + type=str, + required=True, + help="HuggingFace access token, obtained under https://huggingface.co/settings/tokens", + ) + parser.add_argument( + "--spec", + type=_load_spec, + required=True, + help="Path to the spec file", + ) + parsed = parser.parse_args() + _main( + parsed.username, + spec=parsed.spec, + api=HfApi(token=parsed.token), + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/compiler/convert_weight.py b/python/mlc_chat/compiler/convert_weight.py index 0b157d2369..cd762de08a 100644 --- a/python/mlc_chat/compiler/convert_weight.py +++ b/python/mlc_chat/compiler/convert_weight.py @@ -105,14 +105,7 @@ def _check_param(name: str, param: NDArray): total_bytes += 
math.prod(param.shape) * np.dtype(param.dtype).itemsize if named_params: raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}") - # dump to output directory - tvmjs.dump_ndarray_cache( - param_dict, - str(args.output), - meta_data={"ParamSize": len(param_dict)}, - encode_format="raw", - ) - logger.info("Saved to directory: %s", bold(str(args.output))) + # Log necessary statistics logger.info( "%s after quantization: %.3f GB", green("Parameter size"), @@ -124,6 +117,18 @@ def _check_param(name: str, param: NDArray): green("Bits per parameter"), total_bytes * 8.0 / total_params, ) + # dump to output directory + tvmjs.dump_ndarray_cache( + param_dict, + str(args.output), + meta_data={ + "ParamSize": len(param_dict), + "ParamBytes": total_bytes, + "BitsPerParam": total_bytes * 8.0 / total_params, + }, + encode_format="raw", + ) + logger.info("Saved to directory: %s", bold(str(args.output))) def convert_weight( # pylint: disable=too-many-arguments diff --git a/python/mlc_chat/compiler/loader/huggingface_loader.py b/python/mlc_chat/compiler/loader/huggingface_loader.py index 651c43b21f..9611cda87f 100644 --- a/python/mlc_chat/compiler/loader/huggingface_loader.py +++ b/python/mlc_chat/compiler/loader/huggingface_loader.py @@ -77,7 +77,7 @@ def __init__( The quantization mapping from MLC to quantized MLC parameters, default to None, which means no quantization. 
""" - assert path.is_file() + assert path.is_file(), f"Path {path} is not a file" self.stats = Stats() self.extern_param_map = extern_param_map self.cached_files = {} diff --git a/python/mlc_chat/support/download.py b/python/mlc_chat/support/download.py index 6716cfdac4..ff24a76873 100644 --- a/python/mlc_chat/support/download.py +++ b/python/mlc_chat/support/download.py @@ -52,14 +52,14 @@ def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None: path.parent.mkdir(parents=True, exist_ok=True) -def git_clone(url: str, destination: Path) -> None: +def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None: """Clone a git repository into a directory.""" repo_name = ".tmp" command = ["git", "clone", url, repo_name] _ensure_directory_not_exist(destination, force_redo=False) try: with tempfile.TemporaryDirectory() as tmp_dir: - logger.info("Cloning git repo %s to %s", url, destination) + logger.info("[Git] Cloning %s to %s", url, destination) subprocess.run( command, env={"GIT_LFS_SKIP_SMUDGE": "1"}, @@ -68,7 +68,10 @@ def git_clone(url: str, destination: Path) -> None: stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) - shutil.move(os.path.join(tmp_dir, repo_name), str(destination)) + git_dir = os.path.join(tmp_dir, repo_name) + if not ignore_lfs: + git_lfs_pull(Path(git_dir)) + shutil.move(git_dir, str(destination)) except subprocess.CalledProcessError as error: raise ValueError( f"Git clone failed with return code {error.returncode}: {error.stderr}. 
" @@ -76,6 +79,26 @@ def git_clone(url: str, destination: Path) -> None: ) from error +def git_lfs_pull(repo_dir: Path) -> None: + """Pull files with Git LFS.""" + filenames = ( + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "ls-files", "-n"], + stderr=subprocess.STDOUT, + ) + .decode("utf-8") + .splitlines() + ) + logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames) + with tqdm.redirect(): + for file in tqdm.tqdm(filenames): + logger.info("[Git LFS] Downloading %s", file) + subprocess.check_output( + ["git", "-C", str(repo_dir), "lfs", "pull", file], + stderr=subprocess.STDOUT, + ) + + def download_file( url: str, destination: Path, @@ -124,7 +147,7 @@ def download_mlc_weights( # pylint: disable=too-many-locals with tempfile.TemporaryDirectory() as tmp_dir_prefix: tmp_dir = Path(tmp_dir_prefix) / "tmp" git_url = git_url_template.format(user=user, repo=repo) - git_clone(git_url, tmp_dir) + git_clone(git_url, tmp_dir, ignore_lfs=True) shutil.rmtree(tmp_dir / ".git", ignore_errors=True) with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file: param_metadata = json.load(in_file)["records"]