diff --git a/python/mlc_llm/support/constants.py b/python/mlc_llm/support/constants.py
index 82697ff71a..989e4b3191 100644
--- a/python/mlc_llm/support/constants.py
+++ b/python/mlc_llm/support/constants.py
@@ -2,6 +2,7 @@
 import os
 import sys
 from pathlib import Path
+from typing import List
 
 
 def _check():
@@ -45,11 +46,21 @@ def _get_dso_suffix() -> str:
     return "so"
 
 
+def _get_test_model_path() -> List[Path]:
+    if "MLC_TEST_MODEL_PATH" in os.environ:
+        return [Path(p) for p in os.environ["MLC_TEST_MODEL_PATH"].split(os.pathsep)]
+    # By default, reuse the cache dir populated by `mlc_llm chat`.
+    # We deliberately do not auto-download models for test cases,
+    # to avoid networking dependencies in the test suite.
+    return [_get_cache_dir() / "model_weights" / "mlc-ai"]
+
+
 MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
 MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None)
 MLC_CACHE_DIR: Path = _get_cache_dir()
 MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON")
 MLC_DSO_SUFFIX = _get_dso_suffix()
+MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path()
 
 
 _check()
diff --git a/python/mlc_llm/testing/__init__.py b/python/mlc_llm/testing/__init__.py
index e803641043..be72325b94 100644
--- a/python/mlc_llm/testing/__init__.py
+++ b/python/mlc_llm/testing/__init__.py
@@ -1,3 +1,4 @@
 """
 Test and debug tools for MLC LLM
 """
+from .pytest_utils import require_test_model
diff --git a/python/mlc_llm/testing/pytest_utils.py b/python/mlc_llm/testing/pytest_utils.py
new file mode 100644
index 0000000000..850f4c6b82
--- /dev/null
+++ b/python/mlc_llm/testing/pytest_utils.py
@@ -0,0 +1,53 @@
+"""Extra utilities for marking tests."""
+import functools
+from typing import Callable
+
+import pytest
+
+from mlc_llm.support.constants import MLC_TEST_MODEL_PATH
+
+
+def require_test_model(model: str):
+    """Test case decorator that requires a model to be available locally.
+
+    Examples
+    --------
+    .. code::
+
+        @require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
+        def test_reload_reset_unload(model):
+            # `model` now points to the local path resolved via
+            # MLC_TEST_MODEL_PATH (or the default download cache)
+            engine = mlc_llm.MLCEngine(model)
+            # test code follows
+
+    Parameters
+    ----------
+    model : str
+        The model directory name.
+    """
+    model_path = None
+    for base_path in MLC_TEST_MODEL_PATH:
+        if (base_path / model / "mlc-chat-config.json").is_file():
+            model_path = base_path / model
+    missing_model = model_path is None
+    message = (
+        f"Model {model} not found in candidate paths {[str(p) for p in MLC_TEST_MODEL_PATH]}."
+        " If you set MLC_TEST_MODEL_PATH, make sure the model directory lives under one of those paths;"
+        " by default we reuse the download cache, so running `mlc_llm chat` first will fetch the required models."
+    )
+
+    def _decorator(func: Callable[[str], None]):
+        wrapped = functools.partial(func, str(model_path))
+        wrapped.__name__ = func.__name__  # type: ignore
+
+        @functools.wraps(wrapped)
+        def wrapper(*args, **kwargs):
+            if missing_model:
+                print(f"{message} Skipping...")
+                return
+            wrapped(*args, **kwargs)
+
+        return pytest.mark.skipif(missing_model, reason=message)(wrapper)
+
+    return _decorator
diff --git a/tests/python/json_ffi/test_json_ffi_engine.py b/tests/python/json_ffi/test_json_ffi_engine.py
index 6468a93c0d..bff1ba7df0 100644
--- a/tests/python/json_ffi/test_json_ffi_engine.py
+++ b/tests/python/json_ffi/test_json_ffi_engine.py
@@ -4,6 +4,7 @@
 from pydantic import BaseModel
 
 from mlc_llm.json_ffi import JSONFFIEngine
+from mlc_llm.testing import require_test_model
 
 chat_completion_prompts = [
     "What is the meaning of life?",
@@ -142,9 +143,9 @@ class Schema(BaseModel):
            print(f"Output {req_id}({i}):{output}\n")
 
 
-def test_chat_completion():
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
+def test_chat_completion(model):
     # Create engine.
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
@@ -160,9 +161,9 @@ def test_chat_completion():
     engine.terminate()
 
 
-def test_reload_reset_unload():
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
+def test_reload_reset_unload(model):
     # Create engine.
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
@@ -179,8 +180,8 @@ def test_reload_reset_unload():
     engine.terminate()
 
 
-def test_json_schema_with_system_prompt():
-    model = "HF://mlc-ai/Hermes-2-Pro-Mistral-7B-q4f16_1-MLC"
+@require_test_model("Hermes-2-Pro-Mistral-7B-q4f16_1-MLC")
+def test_json_schema_with_system_prompt(model):
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
diff --git a/tests/python/json_ffi/test_json_ffi_engine_image.py b/tests/python/json_ffi/test_json_ffi_engine_image.py
index 2e0cb89878..3f0399f792 100644
--- a/tests/python/json_ffi/test_json_ffi_engine_image.py
+++ b/tests/python/json_ffi/test_json_ffi_engine_image.py
@@ -4,6 +4,7 @@
 import requests
 
 from mlc_llm.json_ffi import JSONFFIEngine
+from mlc_llm.testing import require_test_model
 
 
 def base64_encode_image(url: str) -> str:
@@ -69,9 +70,9 @@ def run_chat_completion(
            print(f"Output {req_id}({i}):{output}\n")
 
 
-def test_chat_completion():
+@require_test_model("llava-1.5-7b-hf-q4f16_1-MLC")
+def test_chat_completion(model):
     # Create engine.
-    model = "dist/llava-1.5-7b-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
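For reviewers trying the new decorator locally, below is a minimal usage sketch; it is not part of the patch above. The base directory /data/models and the test name are hypothetical, while MLC_TEST_MODEL_PATH, require_test_model, and the mlc-chat-config.json check are as introduced in the diff. Note that MLC_TEST_MODEL_PATH is read when mlc_llm.support.constants is first imported, so it must be set before mlc_llm is imported (or exported in the shell before running pytest).

    # Hypothetical example: point the test helpers at a local directory that
    # already contains Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json.
    import os

    # Must be set before mlc_llm is imported; accepts an os.pathsep-separated list.
    os.environ["MLC_TEST_MODEL_PATH"] = os.pathsep.join(["/data/models"])

    from mlc_llm.testing import require_test_model  # pylint: disable=wrong-import-position


    @require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
    def test_model_path_resolution(model):
        # The decorator injects the resolved directory as a string; the test is
        # skipped automatically when the model is not present locally.
        assert os.path.isfile(os.path.join(model, "mlc-chat-config.json"))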