Skip to content

Commit

Permalink
Compile Model Preset without External config.json (mlc-ai#1151)
Browse files Browse the repository at this point in the history
This PR adds support for compiling a preset of models without
having to provide a `config.json` on disk using the commands below:

```diff
python -m mlc_chat.cli.compile \
       --quantization q4f16_1 -o /tmp/1.so \
-       --config /models/Llama-2-7b-chat-hf
+       --config llama2_7b
```

This allows easier testing and binary distribution without having to
depend on external model directory.
  • Loading branch information
junrushao authored Oct 30, 2023
1 parent 0a25374 commit 1a79a53
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 14 deletions.
2 changes: 1 addition & 1 deletion python/mlc_chat/cli/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():

def _parse_config(path: Union[str, Path]) -> Path:
try:
return detect_config(Path(path))
return detect_config(path)
except ValueError as err:
raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")

Expand Down
9 changes: 3 additions & 6 deletions python/mlc_chat/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
but users could optionally import it if they want to use the compiler.
"""
from . import compiler_pass
from .compile import ( # pylint: disable=redefined-builtin
CompileArgs,
OptimizationFlags,
compile,
)
from .model import MODELS, Model
from .compile import CompileArgs, compile # pylint: disable=redefined-builtin
from .flags_optimization import OptimizationFlags
from .model import MODEL_PRESETS, MODELS, Model
from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
from .quantization import QUANT
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Model definition for the compiler."""
from .model import MODELS, Model
from .model import MODEL_PRESETS, MODELS, Model
2 changes: 2 additions & 0 deletions python/mlc_chat/compiler/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,5 @@ class Model:
quantize={},
)
}

MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG
30 changes: 24 additions & 6 deletions python/mlc_chat/support/auto_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Help function for detecting the model configuration file `config.json`"""
import json
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from .style import green

Expand All @@ -14,25 +15,42 @@
FOUND = green("Found")


def detect_config(config_path: Path) -> Path:
"""Detect and return the path that points to config.json. If config_path is a directory,
def detect_config(config: Union[str, Path]) -> Path:
"""Detect and return the path that points to config.json. If `config` is a directory,
it looks for config.json below it.
Parameters
---------
config_path : pathlib.Path
The path to config.json or the directory containing config.json.
config : Union[str, pathlib.Path]
The preset name of the model, or the path to `config.json`, or the directory containing
`config.json`.
Returns
-------
config_json_path : pathlib.Path
The path points to config.json.
"""
from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel
MODEL_PRESETS,
)

if isinstance(config, str) and config in MODEL_PRESETS:
content = MODEL_PRESETS[config]
temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with
suffix=".json",
delete=False,
)
logger.info("%s preset model configuration: %s", FOUND, temp_file.name)
config_path = Path(temp_file.name)
with config_path.open("w", encoding="utf-8") as config_file:
json.dump(content, config_file, indent=2)
else:
config_path = Path(config)
if not config_path.exists():
raise ValueError(f"{config_path} does not exist.")

if config_path.is_dir():
# search config.json under config_path
# search config.json under config path
config_json_path = config_path / "config.json"
if not config_json_path.exists():
raise ValueError(f"Fail to find config.json under {config_path}.")
Expand Down

0 comments on commit 1a79a53

Please sign in to comment.