Skip to content

Commit

Permalink
Continuous Model Delivery
Browse files Browse the repository at this point in the history
This PR provides a script that automatically quantizes models from
HuggingFace using various quantization formats as specified.

Example: given the following JSON spec file:

```json
{
  "destination": "{username}/{model_id}-{quantization}", # Name of HF repo
  "default_quantization": ["q0f16", "q0f32", "q3f16_1", "q4f16_1", "q4f32_1"],
  "tasks": [
    {
      "model_id": "Llama-2-7b-hf",
      "model": "/models/Llama-2-7b-hf", # Can be HF URL or a local path
      "context_window_size": 4096,
      "conv_template": "LM",
      "quantization": [
        {
          "format": "q4f16_awq",
          "model": "https://huggingface.co/TheBloke/Llama-2-7B-AWQ", # Overriding default `source`
          "source_format": "awq"
        }
      ]
    }
  ]
}
```

The script will automatically run each quantization and upload the results to
the following repos:
- https://huggingface.co/junrushao/Llama-2-7b-hf-q0f16
- https://huggingface.co/junrushao/Llama-2-7b-hf-q0f32
- https://huggingface.co/junrushao/Llama-2-7b-hf-q3f16_1
- https://huggingface.co/junrushao/Llama-2-7b-hf-q4f16_1
- https://huggingface.co/junrushao/Llama-2-7b-hf-q4f32_1
- https://huggingface.co/junrushao/Llama-2-7b-hf-q4f16_awq
  • Loading branch information
junrushao committed Nov 16, 2023
1 parent ceb27d5 commit 2233e3d
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 13 deletions.
258 changes: 258 additions & 0 deletions python/mlc_chat/cli/delivery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
"""Continuous model delivery for MLC LLM models."""
import argparse
import dataclasses
import json
import logging
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union

from huggingface_hub import HfApi # pylint: disable=import-error
from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error

from ..support.argparse import ArgumentParser
from ..support.download import git_clone
from ..support.style import bold, green, red

# Configure the root logger once at import time: "{"-style records carrying a
# timestamp, the level, and the file/line that emitted the message.
logging.basicConfig(
    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
    datefmt="%Y-%m-%d %H:%M:%S",
    style="{",
    level=logging.INFO,
)

# Module-level logger shared by every helper in this script.
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ModelInfo:
    """Description of a single (model, quantization) delivery job.

    The fields are forwarded as command-line options to the ``mlc_chat``
    sub-commands that generate the chat config and convert the weights.
    """

    # Short model name; reported in the failure log message.
    model_id: str
    # Local directory with the source weights, passed as ``--model``.
    model: Path
    # Passed as ``--conv-template`` to ``gen_mlc_chat_config``.
    conv_template: str
    # Passed as ``--context-window-size`` to ``gen_mlc_chat_config``.
    context_window_size: int
    # Quantization format name, passed as ``--quantization``.
    quantization: str
    # Passed as ``--source-format`` to ``convert_weight``.
    source_format: str = dataclasses.field(default="auto")


class DeferredScope:
    """A context manager that defers execution of functions until exiting the scope.

    Registered functions run in reverse registration order (last added, first
    run) when the ``with`` block is left, whether it finished normally or
    raised. Every registered function is attempted even if an earlier one
    fails, so one broken cleanup cannot leak the resources guarded by the rest.
    """

    def __init__(self):
        # Callbacks in registration order; drained from the end on exit.
        self.deferred_functions = []

    def add(self, func: Callable[[], None]):
        """Add a function to be executed when exiting the scope."""
        self.deferred_functions.append(func)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Run all deferred functions, newest first.

        Returns
        -------
        suppress : bool
            Always False, so an exception raised inside the ``with`` body
            keeps propagating.

        Raises
        ------
        Exception
            The first exception raised by a deferred function, re-raised only
            after the remaining deferred functions have also been run.
        """
        first_error = None
        # Drain the list so a scope that is entered again does not re-run
        # callbacks that already executed.
        while self.deferred_functions:
            func = self.deferred_functions.pop()
            try:
                func()
            except Exception as error:  # pylint: disable=broad-except
                # Remember the failure but keep going: the callbacks still
                # pending guard other resources (e.g. temporary directories).
                if first_error is None:
                    first_error = error
        if first_error is not None:
            raise first_error
        return False

    def create_temp_dir(self) -> Path:
        """Create a temporary directory that will be deleted when exiting the scope."""
        temp_dir = tempfile.mkdtemp()
        # ignore_errors: removal is best-effort and must not fail the scope exit.
        self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True))
        return Path(temp_dir)


def _clone_repo(model: Union[str, Path], deferred: DeferredScope) -> Path:
    """Turn a model source into a local directory, cloning it when it is a URL.

    Parameters
    ----------
    model : Union[str, Path]
        Either a local path (``Path`` or ``str``) or a git URL starting with
        ``https://`` or ``git://``.

    deferred : DeferredScope
        Scope that owns the temporary directory a clone is placed in; the
        clone is removed when that scope exits.

    Returns
    -------
    path : Path
        The existing local path, or ``<temp dir>/repo`` for a fresh clone.

    Raises
    ------
    ValueError
        If ``model`` is not a URL and does not exist on disk.
    """
    is_remote = not isinstance(model, Path) and model.startswith(("https://", "git://"))
    if is_remote:
        checkout = deferred.create_temp_dir() / "repo"
        # ignore_lfs=False: the clone must also fetch Git LFS content.
        git_clone(model, checkout, ignore_lfs=False)
        return checkout
    local_path = model if isinstance(model, Path) else Path(model)
    if not local_path.exists():
        raise ValueError(f"Invalid model source: {model}")
    return local_path


def _run_quantization(
    model_info: ModelInfo,
    repo: str,
    api: HfApi,
) -> bool:
    """Quantize one model and upload the output folder to a HuggingFace repo.

    The target repo is created first; if it already exists (HTTP 409) it is
    deleted and recreated so the upload starts from an empty repo. The two
    ``mlc_chat`` sub-commands write into a temporary directory, with their
    combined stdout/stderr captured in ``logs.txt`` inside that directory. The
    directory is uploaded whether or not conversion succeeded, so the logs are
    published alongside any output.

    Parameters
    ----------
    model_info : ModelInfo
        The job description; ``model_info.model`` must be a local ``Path``.

    repo : str
        HuggingFace repo id to create and upload to.

    api : HfApi
        Authenticated HuggingFace client.

    Returns
    -------
    succeeded : bool
        False when no ``ndarray-cache.json`` exists in the output directory
        after conversion, True otherwise.

    Raises
    ------
    HfHubHTTPError
        If creating the repo fails for a reason other than "already exists".

    subprocess.CalledProcessError
        If ``mlc_chat gen_mlc_chat_config`` exits with a non-zero status.
    """
    logger.info("[HF] Creating repo https://huggingface.co/%s", repo)
    try:
        api.create_repo(repo_id=repo, private=False)
    except HfHubHTTPError as error:
        # 409 Conflict means the repo exists; anything else is a real failure.
        if error.response.status_code != 409:
            raise
        logger.info("[HF] Repo already exists. Recreating...")
        api.delete_repo(repo_id=repo)
        api.create_repo(repo_id=repo, private=False)
        logger.info("[HF] Repo recreated")
    succeeded = True
    with tempfile.TemporaryDirectory() as output_dir:
        log_path = Path(output_dir) / "logs.txt"
        with log_path.open("a", encoding="utf-8") as log_file:
            assert isinstance(model_info.model, Path)
            logger.info("[MLC] Processing in directory: %s", output_dir)
            subprocess.run(
                [
                    "mlc_chat",
                    "gen_mlc_chat_config",
                    "--model",
                    str(model_info.model),
                    "--quantization",
                    model_info.quantization,
                    "--conv-template",
                    model_info.conv_template,
                    "--context-window-size",
                    str(model_info.context_window_size),
                    "--output",
                    output_dir,
                ],
                check=True,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
            # check=False: a failed conversion is detected below through the
            # missing weights metadata, so the logs can still be uploaded.
            subprocess.run(
                [
                    "mlc_chat",
                    "convert_weight",
                    "--model",
                    str(model_info.model),
                    "--quantization",
                    model_info.quantization,
                    "--source-format",
                    model_info.source_format,
                    "--output",
                    output_dir,
                ],
                check=False,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
            logger.info("[MLC] Complete!")
        # The log file is closed at this point, so its content is fully
        # written before the folder is inspected and uploaded.
        if not (Path(output_dir) / "ndarray-cache.json").exists():
            logger.error(
                "[%s] Model %s. Quantization %s. No weights metadata found.",
                red("FAILED"),
                model_info.model_id,
                model_info.quantization,
            )
            succeeded = False
        logger.info("[HF] Uploading to: https://huggingface.co/%s", repo)
        api.upload_folder(
            folder_path=output_dir,
            repo_id=repo,
            commit_message="Initial commit",
        )
    return succeeded


def _main(  # pylint: disable=too-many-locals
    username: str,
    api: HfApi,
    spec: Dict[str, Any],
):
    """Run every delivery task in ``spec`` and log a summary of the failures.

    Parameters
    ----------
    username : str
        HuggingFace user name, substituted for ``{username}`` in the
        destination template.

    api : HfApi
        Authenticated HuggingFace client used to create repos and upload.

    spec : Dict[str, Any]
        The parsed delivery spec. Keys read here:

        - ``"tasks"``: list of tasks, each with ``"model_id"``, ``"model"``,
          ``"context_window_size"``, ``"conv_template"`` and optionally
          ``"quantization"`` (extra per-task quantizations).
        - ``"default_quantization"``: quantizations applied to every task.
        - ``"destination"`` (optional): repo-name template, defaulting to
          ``"{username}/{model_id}-{quantization}"``.

        A quantization entry is either a format name (``str``) or a dict whose
        ``"format"`` key holds the name and whose remaining keys override
        fields of the job (e.g. ``"model"``, ``"source_format"``). The spec is
        not modified.
    """
    # Resolve the template once so the logged repo and the repo actually
    # written to are always the same, including when "destination" is absent.
    destination = spec.get("destination", "{username}/{model_id}-{quantization}")
    total_tasks = len(spec["tasks"])
    failed_cases: List[Tuple[str, str]] = []
    for task_index, task in enumerate(spec["tasks"], 1):
        with DeferredScope() as deferred:
            logger.info(
                bold("[{task_index}/{total_tasks}] Processing model: ").format(
                    task_index=task_index,
                    total_tasks=total_tasks,
                )
                + green(task["model_id"])
            )
            # The task-level source is resolved once and shared by all of the
            # task's quantizations; it is cleaned up when `deferred` exits.
            model = _clone_repo(task["model"], deferred)
            for quantization in spec["default_quantization"] + task.get("quantization", []):
                model_info = {
                    "model_id": task["model_id"],
                    "model": model,
                    "context_window_size": task["context_window_size"],
                    "conv_template": task["conv_template"],
                }
                if isinstance(quantization, str):
                    model_info["quantization"] = quantization
                else:
                    # Work on a copy: popping "format" from the entry itself
                    # would alter the caller's spec.
                    overrides = dict(quantization)
                    model_info["quantization"] = overrides.pop("format")
                    model_info.update(overrides)
                repo = destination.format(
                    username=username,
                    model_id=model_info["model_id"],
                    quantization=model_info["quantization"],
                )
                logger.info(
                    "%s%s. %s%s. %s%s",
                    bold("Model: "),
                    green(task["model_id"]),
                    bold("Quantization: "),
                    green(model_info["quantization"]),
                    bold("Repo: "),
                    green(f"https://huggingface.co/{repo}"),
                )
                with DeferredScope() as inner_deferred:
                    # A per-quantization "model" override may be a URL; it is
                    # cloned into a directory that lives only for this run.
                    model_info["model"] = _clone_repo(model_info["model"], inner_deferred)
                    result = _run_quantization(
                        ModelInfo(**model_info),
                        repo=repo,
                        api=api,
                    )
                    if not result:
                        failed_cases.append(
                            (task["model_id"], model_info["quantization"]),
                        )
    if failed_cases:
        logger.info("Total %s %s:", len(failed_cases), red("failures"))
        for model_id, quantization in failed_cases:
            logger.info("  Model %s. Quantization %s.", model_id, quantization)


def main():
    """Entry point."""

    def _load_spec(path_spec: str) -> Dict[str, Any]:
        # Used as an argparse ``type``: turns the ``--spec`` path into the
        # parsed JSON document, or reports a missing file as an argument error.
        spec_path = Path(path_spec)
        if spec_path.exists():
            with spec_path.open("r", encoding="utf-8") as spec_file:
                return json.load(spec_file)
        raise argparse.ArgumentTypeError(f"Spec file does not exist: {spec_path}")

    # (flag, type/converter, help text); every option is mandatory.
    cli_options = (
        ("--username", str, "HuggingFace username"),
        (
            "--token",
            str,
            "HuggingFace access token, obtained under https://huggingface.co/settings/tokens",
        ),
        ("--spec", _load_spec, "Path to the spec file"),
    )
    parser = ArgumentParser("MLC LLM continuous model delivery")
    for flag, converter, help_text in cli_options:
        parser.add_argument(flag, type=converter, required=True, help=help_text)
    args = parser.parse_args()
    _main(
        args.username,
        spec=args.spec,
        api=HfApi(token=args.token),
    )


if __name__ == "__main__":
main()
21 changes: 13 additions & 8 deletions python/mlc_chat/compiler/convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,7 @@ def _check_param(name: str, param: NDArray):
total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
if named_params:
raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}")
# dump to output directory
tvmjs.dump_ndarray_cache(
param_dict,
str(args.output),
meta_data={"ParamSize": len(param_dict)},
encode_format="raw",
)
logger.info("Saved to directory: %s", bold(str(args.output)))
# Log necessary statistics
logger.info(
"%s after quantization: %.3f GB",
green("Parameter size"),
Expand All @@ -124,6 +117,18 @@ def _check_param(name: str, param: NDArray):
green("Bits per parameter"),
total_bytes * 8.0 / total_params,
)
# dump to output directory
tvmjs.dump_ndarray_cache(
param_dict,
str(args.output),
meta_data={
"ParamSize": len(param_dict),
"ParamBytes": total_bytes,
"BitsPerParam": total_bytes * 8.0 / total_params,
},
encode_format="raw",
)
logger.info("Saved to directory: %s", bold(str(args.output)))


def convert_weight( # pylint: disable=too-many-arguments
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/loader/huggingface_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
The quantization mapping from MLC to quantized MLC parameters, default to None, which
means no quantization.
"""
assert path.is_file()
assert path.is_file(), f"Path {path} is not a file"
self.stats = Stats()
self.extern_param_map = extern_param_map
self.cached_files = {}
Expand Down
31 changes: 27 additions & 4 deletions python/mlc_chat/support/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None:
path.parent.mkdir(parents=True, exist_ok=True)


def git_clone(url: str, destination: Path) -> None:
def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None:
"""Clone a git repository into a directory."""
repo_name = ".tmp"
command = ["git", "clone", url, repo_name]
_ensure_directory_not_exist(destination, force_redo=False)
try:
with tempfile.TemporaryDirectory() as tmp_dir:
logger.info("Cloning git repo %s to %s", url, destination)
logger.info("[Git] Cloning %s to %s", url, destination)
subprocess.run(
command,
env={"GIT_LFS_SKIP_SMUDGE": "1"},
Expand All @@ -68,14 +68,37 @@ def git_clone(url: str, destination: Path) -> None:
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
shutil.move(os.path.join(tmp_dir, repo_name), str(destination))
git_dir = os.path.join(tmp_dir, repo_name)
if not ignore_lfs:
git_lfs_pull(Path(git_dir))
shutil.move(git_dir, str(destination))
except subprocess.CalledProcessError as error:
raise ValueError(
f"Git clone failed with return code {error.returncode}: {error.stderr}. "
f"The command was: {command}"
) from error


def git_lfs_pull(repo_dir: Path) -> None:
    """Pull files with Git LFS.

    Lists the LFS-tracked files of the repository and downloads them one at a
    time, logging each file and showing a progress bar.

    Parameters
    ----------
    repo_dir : Path
        Working tree of an already-cloned git repository.

    Raises
    ------
    subprocess.CalledProcessError
        If listing the LFS files or pulling any of them exits non-zero.
    """
    listing = subprocess.check_output(
        ["git", "-C", str(repo_dir), "lfs", "ls-files", "-n"],
        stderr=subprocess.STDOUT,
    )
    filenames = listing.decode("utf-8").splitlines()
    logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames)
    with tqdm.redirect():
        for filename in tqdm.tqdm(filenames):
            logger.info("[Git LFS] Downloading %s", filename)
            # `git lfs pull` reads a positional argument as a remote name, so
            # the path has to be given through --include to select the file.
            # NOTE(review): --include takes comma-separated patterns; a file
            # name containing a comma would need escaping -- confirm if needed.
            subprocess.check_output(
                ["git", "-C", str(repo_dir), "lfs", "pull", "--include", filename],
                stderr=subprocess.STDOUT,
            )


def download_file(
url: str,
destination: Path,
Expand Down Expand Up @@ -124,7 +147,7 @@ def download_mlc_weights( # pylint: disable=too-many-locals
with tempfile.TemporaryDirectory() as tmp_dir_prefix:
tmp_dir = Path(tmp_dir_prefix) / "tmp"
git_url = git_url_template.format(user=user, repo=repo)
git_clone(git_url, tmp_dir)
git_clone(git_url, tmp_dir, ignore_lfs=True)
shutil.rmtree(tmp_dir / ".git", ignore_errors=True)
with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file:
param_metadata = json.load(in_file)["records"]
Expand Down

0 comments on commit 2233e3d

Please sign in to comment.