From c5500f1eb061f264786040a4521d62187b307159 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 26 May 2024 08:59:48 -0400 Subject: [PATCH] [DOWNLOAD] MLC_DOWNLOAD_POLICY and MLC_LLM_READONLY_WEIGHT_CACHES This PR introduces support for MLC_DOWNLOAD_POLICY and MLC_LLM_READONLY_WEIGHT_CACHES Allows reading from readonly cache besides MLC_LLM_HOME. Also introduces a domain subfolder in cached weights --- python/mlc_llm/chat_module.py | 4 +-- python/mlc_llm/support/auto_config.py | 4 +-- python/mlc_llm/support/constants.py | 27 +++++++++++---- python/mlc_llm/support/download.py | 50 ++++++++++++++++++++++++--- 4 files changed, 71 insertions(+), 14 deletions(-) diff --git a/python/mlc_llm/chat_module.py b/python/mlc_llm/chat_module.py index 72d1e5315e..b0eed3cbcf 100644 --- a/python/mlc_llm/chat_module.py +++ b/python/mlc_llm/chat_module.py @@ -356,11 +356,11 @@ def _get_model_path(model: str) -> Tuple[str, str]: """ if model.startswith("HF://"): from mlc_llm.support.download import ( # pylint: disable=import-outside-toplevel - download_mlc_weights, + download_and_cache_mlc_weights, ) logger.info("Downloading model from HuggingFace: %s", model) - mlc_dir = download_mlc_weights(model) + mlc_dir = download_and_cache_mlc_weights(model) cfg_dir = mlc_dir / "mlc-chat-config.json" return str(mlc_dir), str(cfg_dir) diff --git a/python/mlc_llm/support/auto_config.py b/python/mlc_llm/support/auto_config.py index be0ee8af98..f518439c66 100644 --- a/python/mlc_llm/support/auto_config.py +++ b/python/mlc_llm/support/auto_config.py @@ -35,12 +35,12 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path: # pylint: disable=import-outside-toplevel from mlc_llm.model import MODEL_PRESETS - from .download import download_mlc_weights + from .download import download_and_cache_mlc_weights # pylint: enable=import-outside-toplevel if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"): - mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config)) + 
mlc_chat_config_path = Path(download_and_cache_mlc_weights(model_url=mlc_chat_config)) elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS: logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config) content = MODEL_PRESETS[mlc_chat_config].copy() diff --git a/python/mlc_llm/support/constants.py b/python/mlc_llm/support/constants.py index beb402653c..e2638bd703 100644 --- a/python/mlc_llm/support/constants.py +++ b/python/mlc_llm/support/constants.py @@ -13,6 +13,13 @@ def _check(): f"but got {MLC_JIT_POLICY}." ) + if MLC_DOWNLOAD_POLICY not in ["ON", "OFF", "REDO", "READONLY"]: + raise ValueError( + "Invalid MLC_DOWNLOAD_POLICY. " + 'It has to be one of "ON", "OFF", "REDO", "READONLY" ' + f"but got {MLC_DOWNLOAD_POLICY}." + ) + def _get_cache_dir() -> Path: if "MLC_LLM_HOME" in os.environ: @@ -48,23 +55,31 @@ def _get_dso_suffix() -> str: def _get_test_model_path() -> List[Path]: - if "MLC_TEST_MODEL_PATH" in os.environ: - return [Path(p) for p in os.environ["MLC_TEST_MODEL_PATH"].split(os.pathsep)] + if "MLC_LLM_TEST_MODEL_PATH" in os.environ: + return [Path(p) for p in os.environ["MLC_LLM_TEST_MODEL_PATH"].split(os.pathsep)] # by default, we reuse the cache dir via mlc_llm chat # note that we do not auto download for testcase # to avoid networking dependencies - return [ - _get_cache_dir() / "model_weights" / "mlc-ai", - Path(os.path.abspath(os.path.curdir)), + base_list = ["hf"] + return [_get_cache_dir() / "model_weights" / base / "mlc-ai" for base in base_list] + [ + Path(os.path.abspath(os.path.curdir)) ] +def _get_read_only_weight_caches() -> List[Path]: + if "MLC_LLM_READONLY_WEIGHT_CACHES" in os.environ: + return [Path(p) for p in os.environ["MLC_LLM_READONLY_WEIGHT_CACHES"].split(os.pathsep)] + return [] + + MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None) MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None) -MLC_LLM_HOME: Path = _get_cache_dir() MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON") MLC_DSO_SUFFIX = 
_get_dso_suffix() MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path() +MLC_DOWNLOAD_POLICY = os.environ.get("MLC_DOWNLOAD_POLICY", "ON") +MLC_LLM_HOME: Path = _get_cache_dir() +MLC_LLM_READONLY_WEIGHT_CACHES = _get_read_only_weight_caches() _check() diff --git a/python/mlc_llm/support/download.py b/python/mlc_llm/support/download.py index 0b520d69c0..3dcc34cd24 100644 --- a/python/mlc_llm/support/download.py +++ b/python/mlc_llm/support/download.py @@ -13,12 +13,26 @@ import requests # pylint: disable=import-error from . import logging, tqdm -from .constants import MLC_LLM_HOME, MLC_TEMP_DIR +from .constants import ( + MLC_DOWNLOAD_POLICY, + MLC_LLM_HOME, + MLC_LLM_READONLY_WEIGHT_CACHES, + MLC_TEMP_DIR, +) from .style import bold logger = logging.getLogger(__name__) +def log_download_policy(): + """log current download policy""" + logger.info( + "%s = %s. Can be one of: ON, OFF, REDO, READONLY", + bold("MLC_DOWNLOAD_POLICY"), + MLC_DOWNLOAD_POLICY, + ) + + def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None: if path.exists(): if force_redo: @@ -110,12 +124,16 @@ def download_file( return url, destination -def download_mlc_weights( # pylint: disable=too-many-locals +def download_and_cache_mlc_weights( # pylint: disable=too-many-locals model_url: str, num_processes: int = 4, - force_redo: bool = False, + force_redo: Optional[bool] = None, ) -> Path: """Download weights for a model from the HuggingFace Git LFS repo.""" + log_download_policy() + if MLC_DOWNLOAD_POLICY == "OFF": + raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_POLICY=OFF") + prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], "" mlc_prefix = next(p for p in prefixes if model_url.startswith(p)) assert mlc_prefix @@ -126,12 +144,36 @@ def download_mlc_weights( # pylint: disable=too-many-locals if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix): raise ValueError(f"Invalid model URL: {model_url}") user, repo = 
model_url[len(mlc_prefix) :].split("/") - git_dir = MLC_LLM_HOME / "model_weights" / user / repo + domain = "hf" + + readonly_cache_dirs = [] + for base in MLC_LLM_READONLY_WEIGHT_CACHES: + cache_dir = base / domain / user / repo + readonly_cache_dirs.append(str(cache_dir)) + if (cache_dir / "mlc-chat-config.json").is_file(): + logger.info("Use cached weight: %s", bold(str(cache_dir))) + return cache_dir + + if force_redo is None: + force_redo = MLC_DOWNLOAD_POLICY == "REDO" + + git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo + readonly_cache_dirs.append(str(git_dir)) + try: _ensure_directory_not_exist(git_dir, force_redo=force_redo) except ValueError: logger.info("Weights already downloaded: %s", bold(str(git_dir))) return git_dir + + if MLC_DOWNLOAD_POLICY == "READONLY": + raise RuntimeError( + f"Cannot find cache for {model_url}, " + "cannot proceed to download as MLC_DOWNLOAD_POLICY=READONLY, " + "please check settings MLC_LLM_READONLY_WEIGHT_CACHES, " + f"local path candidates: {readonly_cache_dirs}" + ) + with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix: tmp_dir = Path(tmp_dir_prefix) / "tmp" git_url = git_url_template.format(user=user, repo=repo)