Continuous Model Delivery #1272

Merged: 1 commit, Nov 16, 2023
257 changes: 257 additions & 0 deletions python/mlc_chat/cli/delivery.py
@@ -0,0 +1,257 @@
"""Continuous model delivery for MLC LLM models."""
import argparse
import dataclasses
import json
import logging
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union

from huggingface_hub import HfApi # pylint: disable=import-error
from huggingface_hub.utils import HfHubHTTPError # pylint: disable=import-error

from ..support.argparse import ArgumentParser
from ..support.download import git_clone
from ..support.style import bold, green, red

logging.basicConfig(
    level=logging.INFO,
    style="{",
    datefmt="%Y-%m-%d %H:%M:%S",
    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
)

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ModelInfo:
    """Necessary information for the model delivery"""

    model_id: str
    model: Path
    conv_template: str
    context_window_size: int
    quantization: str
    source_format: str = "auto"


class DeferredScope:
    """A context manager that defers execution of functions until exiting the scope."""

    def __init__(self):
        self.deferred_functions = []

    def add(self, func: Callable[[], None]):
        """Add a function to be executed when exiting the scope."""
        self.deferred_functions.append(func)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for func in reversed(self.deferred_functions):
            func()
        return False

    def create_temp_dir(self) -> Path:
        """Create a temporary directory that will be deleted when exiting the scope."""
        temp_dir = tempfile.mkdtemp()
        self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True))
        return Path(temp_dir)


def _clone_repo(model: Union[str, Path], deferred: DeferredScope) -> Path:
    if isinstance(model, Path):
        if not model.exists():
            raise ValueError(f"Invalid model source: {model}")
        return model
    if model.startswith("https://") or model.startswith("git://"):
        result = deferred.create_temp_dir() / "repo"
        git_clone(model, result, ignore_lfs=False)
        return result
    result = Path(model)
    if result.exists():
        return result
    raise ValueError(f"Invalid model source: {model}")


def _run_quantization(
    model_info: ModelInfo,
    repo: str,
    api: HfApi,
) -> bool:
    logger.info("[HF] Creating repo https://huggingface.co/%s", repo)
    try:
        api.create_repo(repo_id=repo, private=False)
    except HfHubHTTPError as error:
        if error.response.status_code != 409:
            raise
        logger.info("[HF] Repo already exists. Recreating...")
        api.delete_repo(repo_id=repo)
        api.create_repo(repo_id=repo, private=False)
        logger.info("[HF] Repo recreated")
    succeeded = True
    with tempfile.TemporaryDirectory() as output_dir:
        log_path = Path(output_dir) / "logs.txt"
        with log_path.open("a", encoding="utf-8") as log_file:
            assert isinstance(model_info.model, Path)
            logger.info("[MLC] Processing in directory: %s", output_dir)
            subprocess.run(
                [
                    "mlc_chat",
                    "gen_mlc_chat_config",
                    "--model",
                    str(model_info.model),
                    "--quantization",
                    model_info.quantization,
                    "--conv-template",
                    model_info.conv_template,
                    "--context-window-size",
                    str(model_info.context_window_size),
                    "--output",
                    output_dir,
                ],
                check=True,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
            subprocess.run(
                [
                    "mlc_chat",
                    "convert_weight",
                    "--model",
                    str(model_info.model),
                    "--quantization",
                    model_info.quantization,
                    "--source-format",
                    model_info.source_format,
                    "--output",
                    output_dir,
                ],
                check=False,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
            logger.info("[MLC] Complete!")
        if not (Path(output_dir) / "ndarray-cache.json").exists():
            logger.error(
                "[%s] Model %s. Quantization %s. No weights metadata found.",
                red("FAILED"),
                model_info.model_id,
                model_info.quantization,
            )
            succeeded = False
        logger.info("[HF] Uploading to: https://huggingface.co/%s", repo)
        api.upload_folder(
            folder_path=output_dir,
            repo_id=repo,
            commit_message="Initial commit",
        )
    return succeeded


def _main(  # pylint: disable=too-many-locals
    username: str,
    api: HfApi,
    spec: Dict[str, Any],
):
    failed_cases: List[Tuple[str, str]] = []
    for task_index, task in enumerate(spec["tasks"], 1):
        with DeferredScope() as deferred:
            logger.info(
                bold("[{task_index}/{total_tasks}] Processing model: ").format(
                    task_index=task_index,
                    total_tasks=len(spec["tasks"]),
                )
                + green(task["model_id"])
            )
            model = _clone_repo(task["model"], deferred)
            for quantization in spec["default_quantization"] + task.get("quantization", []):
                model_info = {
                    "model_id": task["model_id"],
                    "model": model,
                    "context_window_size": task["context_window_size"],
                    "conv_template": task["conv_template"],
                }
                if isinstance(quantization, str):
                    model_info["quantization"] = quantization
                else:
                    model_info["quantization"] = quantization.pop("format")
                    model_info.update(quantization)
                repo = spec.get("destination", "{username}/{model_id}-{quantization}").format(
                    username=username,
                    model_id=model_info["model_id"],
                    quantization=model_info["quantization"],
                )
                logger.info(
                    "%s%s. %s%s. %s%s",
                    bold("Model: "),
                    green(task["model_id"]),
                    bold("Quantization: "),
                    green(model_info["quantization"]),
                    bold("Repo: "),
                    green(f"https://huggingface.co/{repo}"),
                )
                with DeferredScope() as inner_deferred:
                    model_info["model"] = _clone_repo(model_info["model"], inner_deferred)
                    result = _run_quantization(
                        ModelInfo(**model_info),
                        repo=spec["destination"].format(
                            username=username,
                            model_id=model_info["model_id"],
                            quantization=model_info["quantization"],
                        ),
                        api=api,
                    )
                    if not result:
                        failed_cases.append(
                            (task["model_id"], model_info["quantization"]),
                        )
    if failed_cases:
        logger.info("Total %s %s:", len(failed_cases), red("failures"))
        for model_id, quantization in failed_cases:
            logger.info("  Model %s. Quantization %s.", model_id, quantization)


def main():
    """Entry point."""

    def _load_spec(path_spec: str) -> Dict[str, Any]:
        path = Path(path_spec)
        if not path.exists():
            raise argparse.ArgumentTypeError(f"Spec file does not exist: {path}")
        with path.open("r", encoding="utf-8") as i_f:
            return json.load(i_f)

    parser = ArgumentParser("MLC LLM continuous model delivery")
    parser.add_argument(
        "--username",
        type=str,
        required=True,
        help="HuggingFace username",
    )
    parser.add_argument(
        "--token",
        type=str,
        required=True,
        help="HuggingFace access token, obtained under https://huggingface.co/settings/tokens",
    )
    parser.add_argument(
        "--spec",
        type=_load_spec,
        required=True,
        help="Path to the spec file",
    )
    parsed = parser.parse_args()
    _main(
        parsed.username,
        spec=parsed.spec,
        api=HfApi(token=parsed.token),
    )


if __name__ == "__main__":
    main()
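
For reference, _main() above reads the --spec JSON as follows: a "tasks" list (each task carrying "model_id", "model", "conv_template", "context_window_size", and an optional extra "quantization" list), a "default_quantization" list, and a "destination" template with {username}, {model_id}, and {quantization} placeholders. The snippet below is a minimal sketch of such a spec inferred from that code, not a file shipped with this PR; the model URL, conversation template, and q4f16_1 quantization name are illustrative placeholders.

# Minimal sketch of a delivery spec, inferred from how _main() consumes it.
# The model URL, conv template, and quantization name below are placeholders.
import json

spec = {
    "destination": "{username}/{model_id}-{quantization}",
    "default_quantization": ["q4f16_1"],
    "tasks": [
        {
            "model_id": "Llama-2-7b-chat-hf",
            "model": "https://huggingface.co/meta-llama/Llama-2-7b-chat-hf",
            "conv_template": "llama-2",
            "context_window_size": 4096,
            # Per-model extras; entries may also be dicts carrying a "format"
            # key plus overrides such as "source_format".
            "quantization": [],
        }
    ],
}

with open("spec.json", "w", encoding="utf-8") as out_file:
    json.dump(spec, out_file, indent=2)

# Hypothetical invocation once the package is installed:
#   python -m mlc_chat.cli.delivery --username my-hf-user --token hf_xxx --spec spec.json
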
21 changes: 13 additions & 8 deletions python/mlc_chat/compiler/convert_weight.py
@@ -105,14 +105,7 @@ def _check_param(name: str, param: NDArray):
         total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
     if named_params:
         raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}")
-    # dump to output directory
-    tvmjs.dump_ndarray_cache(
-        param_dict,
-        str(args.output),
-        meta_data={"ParamSize": len(param_dict)},
-        encode_format="raw",
-    )
-    logger.info("Saved to directory: %s", bold(str(args.output)))
+    # Log necessary statistics
     logger.info(
         "%s after quantization: %.3f GB",
         green("Parameter size"),
@@ -124,6 +117,18 @@ def _check_param(name: str, param: NDArray):
         green("Bits per parameter"),
         total_bytes * 8.0 / total_params,
     )
+    # dump to output directory
+    tvmjs.dump_ndarray_cache(
+        param_dict,
+        str(args.output),
+        meta_data={
+            "ParamSize": len(param_dict),
+            "ParamBytes": total_bytes,
+            "BitsPerParam": total_bytes * 8.0 / total_params,
+        },
+        encode_format="raw",
+    )
+    logger.info("Saved to directory: %s", bold(str(args.output)))
 
 
 def convert_weight(  # pylint: disable=too-many-arguments
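
This change defers the ndarray cache dump until after the statistics are computed, so ParamBytes and BitsPerParam can be recorded next to ParamSize in the cache metadata. Below is a small sketch of reading the figures back, assuming dump_ndarray_cache stores meta_data under the top-level "metadata" key of ndarray-cache.json; the output directory name is a hypothetical example.

# Sketch: read back the statistics recorded by convert_weight.
# Assumes meta_data lands under the "metadata" key of ndarray-cache.json;
# the output directory below is a hypothetical example.
import json
from pathlib import Path

output_dir = Path("dist/Llama-2-7b-chat-hf-q4f16_1")
with (output_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file:
    metadata = json.load(in_file)["metadata"]

print(
    f"{metadata['ParamSize']} weight tensors, "
    f"{metadata['ParamBytes'] / 2**30:.3f} GiB, "
    f"{metadata['BitsPerParam']:.3f} bits per parameter"
)
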
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/loader/huggingface_loader.py
@@ -77,7 +77,7 @@ def __init__(
             The quantization mapping from MLC to quantized MLC parameters, default to None, which
             means no quantization.
         """
-        assert path.is_file()
+        assert path.is_file(), f"Path {path} is not a file"
         self.stats = Stats()
         self.extern_param_map = extern_param_map
         self.cached_files = {}
31 changes: 27 additions & 4 deletions python/mlc_chat/support/download.py
@@ -52,14 +52,14 @@ def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None:
     path.parent.mkdir(parents=True, exist_ok=True)
 
 
-def git_clone(url: str, destination: Path) -> None:
+def git_clone(url: str, destination: Path, ignore_lfs: bool) -> None:
     """Clone a git repository into a directory."""
     repo_name = ".tmp"
     command = ["git", "clone", url, repo_name]
     _ensure_directory_not_exist(destination, force_redo=False)
     try:
         with tempfile.TemporaryDirectory() as tmp_dir:
-            logger.info("Cloning git repo %s to %s", url, destination)
+            logger.info("[Git] Cloning %s to %s", url, destination)
             subprocess.run(
                 command,
                 env={"GIT_LFS_SKIP_SMUDGE": "1"},
@@ -68,14 +68,37 @@ def git_clone(url: str, destination: Path) -> None:
                 stdout=subprocess.DEVNULL,
                 stderr=subprocess.DEVNULL,
             )
-            shutil.move(os.path.join(tmp_dir, repo_name), str(destination))
+            git_dir = os.path.join(tmp_dir, repo_name)
+            if not ignore_lfs:
+                git_lfs_pull(Path(git_dir))
+            shutil.move(git_dir, str(destination))
     except subprocess.CalledProcessError as error:
         raise ValueError(
             f"Git clone failed with return code {error.returncode}: {error.stderr}. "
             f"The command was: {command}"
         ) from error
 
 
+def git_lfs_pull(repo_dir: Path) -> None:
+    """Pull files with Git LFS."""
+    filenames = (
+        subprocess.check_output(
+            ["git", "-C", str(repo_dir), "lfs", "ls-files", "-n"],
+            stderr=subprocess.STDOUT,
+        )
+        .decode("utf-8")
+        .splitlines()
+    )
+    logger.info("[Git LFS] Downloading %d files with Git LFS: %s", len(filenames), filenames)
+    with tqdm.redirect():
+        for file in tqdm.tqdm(filenames):
+            logger.info("[Git LFS] Downloading %s", file)
+            subprocess.check_output(
+                ["git", "-C", str(repo_dir), "lfs", "pull", file],
+                stderr=subprocess.STDOUT,
+            )
+
+
 def download_file(
     url: str,
     destination: Path,
@@ -124,7 +147,7 @@ def download_mlc_weights(  # pylint: disable=too-many-locals
     with tempfile.TemporaryDirectory() as tmp_dir_prefix:
         tmp_dir = Path(tmp_dir_prefix) / "tmp"
         git_url = git_url_template.format(user=user, repo=repo)
-        git_clone(git_url, tmp_dir)
+        git_clone(git_url, tmp_dir, ignore_lfs=True)
         shutil.rmtree(tmp_dir / ".git", ignore_errors=True)
         with (tmp_dir / "ndarray-cache.json").open(encoding="utf-8") as in_file:
             param_metadata = json.load(in_file)["records"]
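
With this change, git_clone always clones with GIT_LFS_SKIP_SMUDGE=1 and only fetches LFS blobs when ignore_lfs=False (as the delivery CLI does), while download_mlc_weights keeps skipping LFS and downloading shards individually. Below is a minimal usage sketch, assuming the import path shown; the repository URL and destination are hypothetical, and the destination must not already exist.

# Usage sketch with a hypothetical repository URL and destination path.
from pathlib import Path

from mlc_chat.support.download import git_clone, git_lfs_pull

destination = Path("/tmp/example-model")  # should not exist yet
# Clone without LFS blobs first, then pull them file by file.
git_clone("https://huggingface.co/example-org/example-model", destination, ignore_lfs=True)
git_lfs_pull(destination)
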