[TESTING] Introduce testing util to manage models
This PR introduces a new env var, MLC_TEST_MODEL_PATH, which accepts a list of
model paths to search when locating models for tests.

If a model is not found, an error message is printed and the test is automatically
skipped, both under pytest and when the test file is run directly.

The path defaults to the cached HF path, so the model can be found as long as
mlc_llm chat has been run. We do not download automatically, to avoid excessive
networking in CI settings.

A follow-up PR is needed for the remaining test cases.
tqchen committed May 22, 2024
1 parent a5e71b3 commit 1ceb1a2
Showing 5 changed files with 76 additions and 7 deletions.
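
For context, here is a minimal sketch of pointing the test suite at locally converted models; the snippet is not part of this commit and the directory names are placeholders. MLC_TEST_MODEL_PATH accepts multiple directories separated by os.pathsep:

# Hypothetical pre-test setup; the directories below are illustrative.
import os

os.environ["MLC_TEST_MODEL_PATH"] = os.pathsep.join(
    [
        "/data/mlc-models",                     # locally converted weights (illustrative)
        "/opt/ci-cache/model_weights/mlc-ai",   # another candidate directory (illustrative)
    ]
)
# With the variable unset, the search falls back to the model_weights/mlc-ai
# directory under the MLC cache, which `mlc_llm chat` populates on download.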
11 changes: 11 additions & 0 deletions python/mlc_llm/support/constants.py
@@ -2,6 +2,7 @@
 import os
 import sys
 from pathlib import Path
+from typing import List


 def _check():
@@ -45,11 +46,21 @@ def _get_dso_suffix() -> str:
     return "so"


+def _get_test_model_path() -> List[Path]:
+    if "MLC_TEST_MODEL_PATH" in os.environ:
+        return [Path(p) for p in os.environ["MLC_TEST_MODEL_PATH"].split(os.pathsep)]
+    # by default, we reuse the cache dir via mlc_llm chat
+    # note that we do not auto download for testcase
+    # to avoid networking dependencies
+    return [_get_cache_dir() / "model_weights" / "mlc-ai"]
+
+
 MLC_TEMP_DIR = os.getenv("MLC_TEMP_DIR", None)
 MLC_MULTI_ARCH = os.environ.get("MLC_MULTI_ARCH", None)
 MLC_CACHE_DIR: Path = _get_cache_dir()
 MLC_JIT_POLICY = os.environ.get("MLC_JIT_POLICY", "ON")
 MLC_DSO_SUFFIX = _get_dso_suffix()
+MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path()


 _check()
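
As a quick illustration of the lookup above, the snippet below (not part of the commit) checks each candidate directory for a converted model by probing for mlc-chat-config.json, mirroring the check used by the test decorator; the model name is only an example.

from mlc_llm.support.constants import MLC_TEST_MODEL_PATH

model = "Llama-2-7b-chat-hf-q4f16_1-MLC"  # example model directory name
for base in MLC_TEST_MODEL_PATH:
    found = (base / model / "mlc-chat-config.json").is_file()
    print(f"{base / model}: {'found' if found else 'missing'}")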
1 change: 1 addition & 0 deletions python/mlc_llm/testing/__init__.py
@@ -1,3 +1,4 @@
 """
 Test and debug tools for MLC LLM
 """
+from .pytest_utils import require_test_model
55 changes: 55 additions & 0 deletions python/mlc_llm/testing/pytest_utils.py
@@ -0,0 +1,55 @@
+"""Extra utilities to mark tests"""
+import functools
+import warnings
+from pathlib import Path
+from typing import Callable
+
+import pytest
+
+from mlc_llm.support.constants import MLC_TEST_MODEL_PATH
+
+
+def require_test_model(model: str):
+    """Testcase decorator to require a model
+
+    Examples
+    --------
+    .. code::
+
+        @require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
+        def test_reload_reset_unload(model):
+            # model now points to the right path
+            # specified by MLC_TEST_MODEL_PATH
+            engine = mlc_llm.MLCEngine(model)
+            # test code follows
+
+    Parameters
+    ----------
+    model : str
+        The model dir name
+    """
+    model_path = None
+    for base_path in MLC_TEST_MODEL_PATH:
+        if (base_path / model / "mlc-chat-config.json").is_file():
+            model_path = base_path / model
+    missing_model = model_path is None
+    message = (
+        f"Model {model} does not exist in candidate paths {[str(p) for p in MLC_TEST_MODEL_PATH]},"
+        " if you set MLC_TEST_MODEL_PATH env var, please ensure model paths are in the right location,"
+        " by default we reuse cache with mlc_llm chat, try to run mlc_llm chat to download right set of models."
+    )
+
+    def _decorator(func: Callable[[str], None]):
+        wrapped = functools.partial(func, str(model_path))
+        wrapped.__name__ = func.__name__
+
+        @functools.wraps(wrapped)
+        def wrapper(*args, **kwargs):
+            if missing_model:
+                print(f"{message} skipping...")
+                return
+            return wrapped(*args, **kwargs)
+
+        return pytest.mark.skipif(missing_model, reason=message)(wrapper)
+
+    return _decorator
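
A minimal sketch of how the decorator behaves outside pytest, assuming the import path established above; smoke_test is a hypothetical function name. When no candidate path contains the model, the wrapper prints the message and returns without running the body; under pytest the same condition marks the test as skipped.

from mlc_llm.testing import require_test_model

@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
def smoke_test(model):  # hypothetical test; `model` receives the resolved path
    print(f"Would load model from: {model}")

if __name__ == "__main__":
    smoke_test()  # body runs only if the model directory was found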
13 changes: 7 additions & 6 deletions tests/python/json_ffi/test_json_ffi_engine.py
@@ -4,6 +4,7 @@
 from pydantic import BaseModel

 from mlc_llm.json_ffi import JSONFFIEngine
+from mlc_llm.testing import require_test_model

 chat_completion_prompts = [
     "What is the meaning of life?",
@@ -142,9 +143,9 @@ class Schema(BaseModel):
         print(f"Output {req_id}({i}):{output}\n")


-def test_chat_completion():
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
+def test_chat_completion(model):
     # Create engine.
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
@@ -160,9 +161,9 @@ def test_chat_completion():
     engine.terminate()


-def test_reload_reset_unload():
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
+def test_reload_reset_unload(model):
     # Create engine.
-    model = "HF://mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
@@ -179,8 +180,8 @@ def test_reload_reset_unload():
     engine.terminate()


-def test_json_schema_with_system_prompt():
-    model = "HF://mlc-ai/Hermes-2-Pro-Mistral-7B-q4f16_1-MLC"
+@require_test_model("Hermes-2-Pro-Mistral-7B-q4f16_1-MLC")
+def test_json_schema_with_system_prompt(model):
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
3 changes: 2 additions & 1 deletion tests/python/json_ffi/test_json_ffi_engine_image.py
@@ -4,6 +4,7 @@
 import requests

 from mlc_llm.json_ffi import JSONFFIEngine
+from mlc_llm.testing import require_test_model


 def base64_encode_image(url: str) -> str:
@@ -69,9 +70,9 @@ def run_chat_completion(
         print(f"Output {req_id}({i}):{output}\n")


+@require_test_model("llava-1.5-7b-hf-q4f16_1-MLC")
 def test_chat_completion():
     # Create engine.
-    model = "dist/llava-1.5-7b-hf-q4f16_1-MLC"
     engine = JSONFFIEngine(
         model,
         max_total_sequence_length=1024,
