[DOWNLOAD] MLC_DOWNLOAD_POLICY and MLC_LLM_READONLY_WEIGHT_CACHES #2421

Merged · 1 commit · May 26, 2024
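In brief, this patch adds two environment variables and renames the weight downloader: MLC_DOWNLOAD_POLICY controls whether weights may be fetched ("ON" by default, or "OFF", "REDO", "READONLY"), and MLC_LLM_READONLY_WEIGHT_CACHES lists extra read-only cache directories consulted before any download. A minimal configuration sketch, under the assumption that both variables are read once when mlc_llm.support.constants is first imported (the cache paths below are placeholders):

```python
import os

# Placeholder cache paths; point them at real, pre-populated weight caches.
# Valid policies per this PR: "ON" (default), "OFF", "REDO", "READONLY".
os.environ["MLC_DOWNLOAD_POLICY"] = "ON"
os.environ["MLC_LLM_READONLY_WEIGHT_CACHES"] = os.pathsep.join(
    ["/opt/shared/mlc-weights", "/mnt/nfs/mlc-weights"]
)

# Import mlc_llm only after the environment is set: constants.py reads
# both variables at module import time.
from mlc_llm.support.download import download_and_cache_mlc_weights  # noqa: E402
```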
4 changes: 2 additions & 2 deletions python/mlc_llm/chat_module.py
@@ -356,11 +356,11 @@ def _get_model_path(model: str) -> Tuple[str, str]:
     """
     if model.startswith("HF://"):
         from mlc_llm.support.download import (  # pylint: disable=import-outside-toplevel
-            download_mlc_weights,
+            download_and_cache_mlc_weights,
         )

         logger.info("Downloading model from HuggingFace: %s", model)
-        mlc_dir = download_mlc_weights(model)
+        mlc_dir = download_and_cache_mlc_weights(model)
         cfg_dir = mlc_dir / "mlc-chat-config.json"
         return str(mlc_dir), str(cfg_dir)
4 changes: 2 additions & 2 deletions python/mlc_llm/support/auto_config.py
@@ -35,12 +35,12 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
     # pylint: disable=import-outside-toplevel
     from mlc_llm.model import MODEL_PRESETS

-    from .download import download_mlc_weights
+    from .download import download_and_cache_mlc_weights

     # pylint: enable=import-outside-toplevel

     if mlc_chat_config.startswith("HF://") or mlc_chat_config.startswith("http"):
-        mlc_chat_config_path = Path(download_mlc_weights(model_url=mlc_chat_config))
+        mlc_chat_config_path = Path(download_and_cache_mlc_weights(model_url=mlc_chat_config))
     elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS:
         logger.info("%s mlc preset model: %s", FOUND, mlc_chat_config)
         content = MODEL_PRESETS[mlc_chat_config].copy()
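Both renamed call sites (chat_module._get_model_path above and detect_mlc_chat_config here) keep their public behavior; only the helper name changes. A hedged usage sketch of the updated path, where the HF:// repo name is a placeholder and the call needs network access unless the weights are already cached:

```python
from pathlib import Path

from mlc_llm.support.auto_config import detect_mlc_chat_config

# Placeholder repo name; any prepackaged "HF://<user>/<repo>" MLC model works.
config_path: Path = detect_mlc_chat_config("HF://mlc-ai/SOME-MLC-QUANTIZED-MODEL")
print(config_path)  # resolved chat-config location under the weight cache
```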
27 changes: 21 additions & 6 deletions python/mlc_llm/support/constants.py
@@ -13,6 +13,13 @@ def _check():
             f"but got {MLC_JIT_POLICY}."
         )

+    if MLC_DOWNLOAD_POLICY not in ["ON", "OFF", "REDO", "READONLY"]:
+        raise ValueError(
+            "Invalid MLC_AUTO_DOWNLOAD_POLICY. "
+            'It has to be one of "ON", "OFF", "REDO", "READONLY"'
+            f"but got {MLC_DOWNLOAD_POLICY}."
+        )
+

 def _get_cache_dir() -> Path:
     if "MLC_LLM_HOME" in os.environ:
@@ -48,23 +55,31 @@ def _get_dso_suffix() -> str:


 def _get_test_model_path() -> List[Path]:
-    if "MLC_TEST_MODEL_PATH" in os.environ:
-        return [Path(p) for p in os.environ["MLC_TEST_MODEL_PATH"].split(os.pathsep)]
+    if "MLC_LLM_TEST_MODEL_PATH" in os.environ:
+        return [Path(p) for p in os.environ["MLC_LLM_TEST_MODEL_PATH"].split(os.pathsep)]
     # by default, we reuse the cache dir via mlc_llm chat
     # note that we do not auto download for testcase
     # to avoid networking dependencies
-    return [
-        _get_cache_dir() / "model_weights" / "mlc-ai",
-        Path(os.path.abspath(os.path.curdir)),
+    base_list = ["hf"]
+    return [_get_cache_dir() / "model_weights" / base / "mlc-ai" for base in base_list] + [
+        Path(os.path.abspath(os.path.curdir))
     ]


+def _get_read_only_weight_caches() -> List[Path]:
+    if "MLC_LLM_READONLY_WEIGHT_CACHES" in os.environ:
+        return [Path(p) for p in os.environ["MLC_LLM_READONLY_WEIGHT_CACHES"].split(os.pathsep)]
+    return []
+
+
 MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
 MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None)
-MLC_LLM_HOME: Path = _get_cache_dir()
 MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON")
 MLC_DSO_SUFFIX = _get_dso_suffix()
 MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path()
+
+MLC_DOWNLOAD_POLICY = os.environ.get("MLC_DOWNLOAD_POLICY", "ON")
+MLC_LLM_HOME: Path = _get_cache_dir()
+MLC_LLM_READONLY_WEIGHT_CACHES = _get_read_only_weight_caches()

 _check()
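One practical effect of these hunks is a new cache layout: weights now sit under a per-domain directory, so the writable cache becomes MLC_LLM_HOME / "model_weights" / "hf" / user / repo, and each entry of MLC_LLM_READONLY_WEIGHT_CACHES is probed at base / "hf" / user / repo. A small illustrative sketch, where the home directory, cache base, and repo name are placeholders:

```python
from pathlib import Path

# Placeholders; in the library, MLC_LLM_HOME comes from _get_cache_dir()
# and user/repo are parsed from the "HF://<user>/<repo>" URL.
mlc_llm_home = Path.home() / ".cache" / "mlc_llm"
user, repo = "mlc-ai", "SOME-MLC-QUANTIZED-MODEL"

# Writable cache location after this PR (note the new "hf" domain segment).
default_cache = mlc_llm_home / "model_weights" / "hf" / user / repo

# Candidate checked inside each MLC_LLM_READONLY_WEIGHT_CACHES entry;
# a cache hit requires an mlc-chat-config.json file to be present there.
readonly_candidate = Path("/opt/shared/mlc-weights") / "hf" / user / repo / "mlc-chat-config.json"

print(default_cache)
print(readonly_candidate)
```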
50 changes: 46 additions & 4 deletions python/mlc_llm/support/download.py
@@ -13,12 +13,26 @@
 import requests  # pylint: disable=import-error

 from . import logging, tqdm
-from .constants import MLC_LLM_HOME, MLC_TEMP_DIR
+from .constants import (
+    MLC_DOWNLOAD_POLICY,
+    MLC_LLM_HOME,
+    MLC_LLM_READONLY_WEIGHT_CACHES,
+    MLC_TEMP_DIR,
+)
 from .style import bold

 logger = logging.getLogger(__name__)


+def log_download_policy():
+    """log current download policy"""
+    logger.info(
+        "%s = %s. Can be one of: ON, OFF, REDO, READONLY",
+        bold("MLC_DOWNLOAD_POLICY"),
+        MLC_DOWNLOAD_POLICY,
+    )
+
+
 def _ensure_directory_not_exist(path: Path, force_redo: bool) -> None:
     if path.exists():
         if force_redo:
@@ -110,12 +124,16 @@ def download_file(
     return url, destination


-def download_mlc_weights(  # pylint: disable=too-many-locals
+def download_and_cache_mlc_weights(  # pylint: disable=too-many-locals
     model_url: str,
     num_processes: int = 4,
-    force_redo: bool = False,
+    force_redo: Optional[bool] = None,
 ) -> Path:
     """Download weights for a model from the HuggingFace Git LFS repo."""
+    log_download_policy()
+    if MLC_DOWNLOAD_POLICY == "OFF":
+        raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_POLICY=OFF")
+
     prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], ""
     mlc_prefix = next(p for p in prefixes if model_url.startswith(p))
     assert mlc_prefix
@@ -126,12 +144,36 @@ def download_mlc_weights(  # pylint: disable=too-many-locals
     if model_url.count("/") != 1 + mlc_prefix.count("/") or not model_url.startswith(mlc_prefix):
         raise ValueError(f"Invalid model URL: {model_url}")
     user, repo = model_url[len(mlc_prefix) :].split("/")
-    git_dir = MLC_LLM_HOME / "model_weights" / user / repo
+    domain = "hf"
+
+    readonly_cache_dirs = []
+    for base in MLC_LLM_READONLY_WEIGHT_CACHES:
+        cache_dir = base / domain / user / repo
+        readonly_cache_dirs.append(str(cache_dir))
+        if (cache_dir / "mlc-chat-config.json").is_file():
+            logger.info("Use cached weight: %s", bold(str(cache_dir)))
+            return cache_dir
+
+    if force_redo is None:
+        force_redo = MLC_DOWNLOAD_POLICY == "REDO"
+
+    git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo
+    readonly_cache_dirs.append(str(git_dir))
+
     try:
         _ensure_directory_not_exist(git_dir, force_redo=force_redo)
     except ValueError:
         logger.info("Weights already downloaded: %s", bold(str(git_dir)))
         return git_dir
+
+    if MLC_DOWNLOAD_POLICY == "READONLY":
+        raise RuntimeError(
+            f"Cannot find cache for {model_url}, "
+            "cannot proceed to download as MLC_DOWNLOAD_POLICY=READONLY, "
+            "please check settings MLC_LLM_READONLY_WEIGHT_CACHES, "
+            f"local path candidates: {readonly_cache_dirs}"
+        )
+
     with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix:
         tmp_dir = Path(tmp_dir_prefix) / "tmp"
         git_url = git_url_template.format(user=user, repo=repo)
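Read together: with MLC_DOWNLOAD_POLICY=OFF the function refuses outright; otherwise it checks each read-only cache, then the writable MLC_LLM_HOME cache, and only proceeds to download when the policy is not READONLY. A hedged sketch of a locked-down deployment, where the cache path and repo name are placeholders and the environment must be set before mlc_llm is imported:

```python
import os

# Placeholder: a real, pre-populated read-only cache directory.
os.environ["MLC_DOWNLOAD_POLICY"] = "READONLY"
os.environ["MLC_LLM_READONLY_WEIGHT_CACHES"] = "/opt/shared/mlc-weights"

from mlc_llm.support.download import download_and_cache_mlc_weights  # noqa: E402

try:
    # Under READONLY this only consults the cache above plus
    # MLC_LLM_HOME/model_weights/hf/<user>/<repo>; it never downloads.
    weights_dir = download_and_cache_mlc_weights("HF://mlc-ai/SOME-MLC-QUANTIZED-MODEL")
    print("Using cached weights at", weights_dir)
except RuntimeError as err:
    # Raised by the new READONLY branch, listing the local paths it tried.
    print("No cached copy found:", err)
```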