Rename --config to --model and Consolidate CLI Messages

junrushao committed Nov 13, 2023 · 1 parent ab2a05b · commit 5d15e99
Showing 8 changed files with 164 additions and 112 deletions.
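In practice, the rename means an invocation along the lines of the following (hypothetical paths and entrypoint spelling):

    python -m mlc_chat.cli.compile --config ./CodeLlama-7b-Instruct-hf --quantization q4f16_1 -o codellama.so

now becomes:

    python -m mlc_chat.cli.compile --model ./CodeLlama-7b-Instruct-hf --quantization q4f16_1 -o codellama.so

The same flag rename applies to convert_weight and gen_mlc_chat_config below.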
59 changes: 15 additions & 44 deletions python/mlc_chat/cli/compile.py
@@ -6,12 +6,14 @@
 from typing import Union
 
 from mlc_chat.compiler import (  # pylint: disable=redefined-builtin
+    HELP,
     MODELS,
     QUANTIZATION,
     OptimizationFlags,
     compile,
 )
 
+from ..support.argparse import ArgumentParser
 from ..support.auto_config import detect_config, detect_model_type
 from ..support.auto_target import detect_target_and_host
 
@@ -26,17 +28,11 @@
 def main():
     """Parse command line argumennts and call `mlc_llm.compiler.compile`."""
 
-    def _parse_config(path: Union[str, Path]) -> Path:
-        try:
-            return detect_config(path)
-        except ValueError as err:
-            raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
-
     def _parse_output(path: Union[str, Path]) -> Path:
         path = Path(path)
         parent = path.parent
         if not parent.is_dir():
-            raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
+            raise ValueError(f"Directory does not exist: {parent}")
         return path
 
     def _check_prefix_symbols(prefix: str) -> str:
@@ -48,88 +44,63 @@ def _check_prefix_symbols(prefix: str) -> str:
             "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)."
         )
 
-    parser = argparse.ArgumentParser("MLC LLM Compiler")
+    parser = ArgumentParser("MLC LLM Compiler")
     parser.add_argument(
-        "--config",
-        type=_parse_config,
+        "--model",
+        type=detect_config,
         required=True,
-        help="Path to config.json file or to the directory that contains config.json, which is "
-        "a HuggingFace standard that defines model architecture, for example, "
-        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
+        help=HELP["model"] + " (required)",
     )
     parser.add_argument(
         "--quantization",
         type=str,
         required=True,
         choices=list(QUANTIZATION.keys()),
-        help="Quantization format.",
+        help=HELP["quantization"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--model-type",
         type=str,
         default="auto",
         choices=["auto"] + list(MODELS.keys()),
-        help="Model architecture, for example, llama. If not set, it is inferred "
-        "from the config.json file. "
-        "(default: %(default)s)",
+        help=HELP["model_type"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--device",
         type=str,
         default="auto",
-        help="The GPU device to compile the model to. If not set, it is inferred from locally "
-        "available GPUs. "
-        "(default: %(default)s)",
+        help=HELP["device_compile"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--host",
         type=str,
         default="auto",
-        help="The host CPU ISA to compile the model to. If not set, it is inferred from the "
-        "local CPU. (default: %(default)s)",
+        help=HELP["host"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--opt",
         type=OptimizationFlags.from_str,
         default="O2",
-        help="Optimization flags. MLC LLM maintains a predefined set of optimization flags, "
-        "denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, "
-        "and O3 represents extreme optimization that could potentially break the system. "
-        "Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. "
-        '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0. '
-        "(default: %(default)s)",
+        help=HELP["opt"] + ' (default: "%(default)s")',
    )
     parser.add_argument(
         "--prefix-symbols",
         type=str,
         default="",
-        help='Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". '
-        "This is useful when compiling multiple models into a single library to avoid symbol "
-        "conflicts. Differet from objcopy, this takes no effect for shared library. "
-        '(default: "")',
+        help=HELP["prefix_symbols"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--max-sequence-length",
         type=int,
         default=None,
-        help="Option to override the maximum sequence length supported by the model. "
-        "An LLM is usually trained with a fixed maximum sequence length, which is usually "
-        "explicitly specified in model spec. By default, if this option is not set explicitly, "
-        "the maximum sequence length is determined by `max_sequence_length` or "
-        "`max_position_embeddings` in config.json, which can be inaccuate for some models.",
+        help=HELP["max_sequence_length"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--output",
         "-o",
         type=_parse_output,
         required=True,
-        help="The name of the output file. The suffix determines if the output file is a "
-        "shared library or objects. Available suffixes: "
-        "1) Linux: .so (shared), .tar (objects); "
-        "2) macOS: .dylib (shared), .tar (objects); "
-        "3) Windows: .dll (shared), .tar (objects); "
-        "4) Android, iOS: .tar (objects); "
-        "5) Web: .wasm (web assembly)",
+        help=HELP["output_compile"] + " (required)",
     )
     parsed = parser.parse_args()
     target, build_func = detect_target_and_host(parsed.device, parsed.host)
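Two details of this hunk are worth noting. Dropping the _parse_config wrapper works because argparse already treats a ValueError raised inside a type= callable as a parse failure, so detect_config (which raises ValueError per the removed except clause) can be passed directly; the same reasoning lets _parse_output raise a plain ValueError. A minimal sketch with the standard-library parser — note that stock argparse replaces the ValueError's message with a generic "invalid value" line, so presumably the project's ArgumentParser wrapper is what surfaces the detailed text:

import argparse
from pathlib import Path


def _parse_output(path: str) -> Path:
    path = Path(path)
    if not path.parent.is_dir():
        # argparse catches this ValueError and reports a parse error
        raise ValueError(f"Directory does not exist: {path.parent}")
    return path


parser = argparse.ArgumentParser("demo")
parser.add_argument("--output", "-o", type=_parse_output, required=True)
# parser.parse_args(["-o", "/no/such/dir/model.so"]) exits with:
#   usage: demo [-h] --output OUTPUT
#   demo: error: argument --output/-o: invalid _parse_output value: '/no/such/dir/model.so'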
60 changes: 23 additions & 37 deletions python/mlc_chat/cli/convert_weight.py
@@ -4,8 +4,9 @@
 from pathlib import Path
 from typing import Union
 
-from mlc_chat.compiler import MODELS, QUANTIZATION, convert_weight
+from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight
 
+from ..support.argparse import ArgumentParser
 from ..support.auto_config import detect_config, detect_model_type
 from ..support.auto_target import detect_device
 from ..support.auto_weight import detect_weight
@@ -21,12 +22,6 @@
 def main():
     """Parse command line argumennts and apply quantization."""
 
-    def _parse_config(path: Union[str, Path]) -> Path:
-        try:
-            return detect_config(path)
-        except ValueError as err:
-            raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
-
     def _parse_source(path: Union[str, Path], config_path: Path) -> Path:
         if path == "auto":
             return config_path.parent
@@ -41,61 +36,52 @@ def _parse_output(path: Union[str, Path]) -> Path:
         path.mkdir(parents=True, exist_ok=True)
         return path
 
-    parser = argparse.ArgumentParser("MLC AutoLLM Quantization Framework")
+    parser = ArgumentParser("MLC AutoLLM Quantization Framework")
     parser.add_argument(
-        "--config",
-        type=_parse_config,
+        "--model",
+        type=detect_config,
         required=True,
-        help="Path to config.json file or to the directory that contains config.json, which is "
-        "a HuggingFace standard that defines model architecture, for example, "
-        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
-    )
-    parser.add_argument(
-        "--source",
-        type=str,
-        default="auto",
-        help="The path to original model weight, infer from `config` if missing. "
-        "(default: %(default)s)",
-    )
-    parser.add_argument(
-        "--source-format",
-        type=str,
-        choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"],
-        default="auto",
-        help="The format of source model weight, infer from `config` if missing. "
-        "(default: %(default)s)",
+        help=HELP["model"] + " (required)",
     )
     parser.add_argument(
         "--quantization",
         type=str,
         required=True,
         choices=list(QUANTIZATION.keys()),
-        help="Quantization format, for example `q4f16_1`.",
+        help=HELP["quantization"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--model-type",
         type=str,
         default="auto",
         choices=["auto"] + list(MODELS.keys()),
-        help="Model architecture, for example, llama. If not set, it is inferred "
-        "from the config.json file. "
-        "(default: %(default)s)",
+        help=HELP["model_type"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--device",
         default="auto",
         type=detect_device,
-        help="The device used to do quantization, for example, / `cuda:0`. "
-        "Detect from local environment if not specified. "
-        "(default: %(default)s)",
+        help=HELP["device_quantize"] + ' (default: "%(default)s")',
     )
+    parser.add_argument(
+        "--source",
+        type=str,
+        default="auto",
+        help=HELP["source"] + ' (default: "%(default)s")',
+    )
+    parser.add_argument(
+        "--source-format",
+        type=str,
+        choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"],
+        default="auto",
+        help=HELP["source_format"] + ' (default: "%(default)s", choices: %(choices)s")',
+    )
     parser.add_argument(
         "--output",
         "-o",
         type=_parse_output,
         required=True,
-        help="The output directory to save the quantized model weight, "
-        "will contain `params_shard_*.bin` and `ndarray-cache.json`.",
+        help=HELP["output_quantize"] + " (required)",
     )
 
     parsed = parser.parse_args()
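The (required, choices: %(choices)s) and (default: "%(default)s") suffixes lean on a stock argparse feature: help strings are %-interpolated against the action's attributes at formatting time, with choices joined by commas. A standalone sketch — the quantization names are a placeholder subset of the real QUANTIZATION keys:

import argparse

parser = argparse.ArgumentParser("demo")
parser.add_argument(
    "--quantization",
    choices=["q4f16_1", "q4f32_1"],  # placeholder subset of QUANTIZATION keys
    required=True,
    # %(choices)s is expanded by argparse when the help text is rendered
    help="Quantization format." + " (required, choices: %(choices)s)",
)
parser.add_argument(
    "--device",
    default="auto",
    help="Device to quantize on." + ' (default: "%(default)s")',
)
parser.print_help()
# The printed help resembles:
#   --quantization {q4f16_1,q4f32_1}
#                  Quantization format. (required, choices: q4f16_1, q4f32_1)
#   --device DEVICE
#                  Device to quantize on. (default: "auto")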
40 changes: 11 additions & 29 deletions python/mlc_chat/cli/gen_mlc_chat_config.py
@@ -1,11 +1,11 @@
 """Command line entrypoint of configuration generation."""
-import argparse
 import logging
 from pathlib import Path
 from typing import Union
 
-from mlc_chat.compiler import CONV_TEMPLATES, MODELS, QUANTIZATION, gen_config
+from mlc_chat.compiler import CONV_TEMPLATES, HELP, MODELS, QUANTIZATION, gen_config
 
+from ..support.argparse import ArgumentParser
 from ..support.auto_config import detect_config, detect_model_type
 
 logging.basicConfig(
@@ -18,13 +18,7 @@
 
 def main():
     """Parse command line argumennts and call `mlc_llm.compiler.gen_config`."""
-    parser = argparse.ArgumentParser("MLC LLM Configuration Generator")
-
-    def _parse_config(path: Union[str, Path]) -> Path:
-        try:
-            return detect_config(path)
-        except ValueError as err:
-            raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
+    parser = ArgumentParser("MLC LLM Configuration Generator")
 
     def _parse_output(path: Union[str, Path]) -> Path:
         path = Path(path)
@@ -33,56 +27,44 @@ def _parse_output(path: Union[str, Path]) -> Path:
         return path
 
     parser.add_argument(
-        "--config",
-        type=_parse_config,
+        "--model",
+        type=detect_config,
         required=True,
-        help="Path to config.json file or to the directory that contains config.json, which is "
-        "a HuggingFace standard that defines model architecture, for example, "
-        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json. "
-        "This `config.json` file is expected to colocate with other configurations, such as "
-        "tokenizer configuration and `generation_config.json`.",
+        help=HELP["model"] + " (required)",
     )
     parser.add_argument(
         "--quantization",
         type=str,
         required=True,
         choices=list(QUANTIZATION.keys()),
-        help="Quantization format.",
+        help=HELP["quantization"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--model-type",
         type=str,
         default="auto",
         choices=["auto"] + list(MODELS.keys()),
-        help="Model architecture, for example, llama. If not set, it is inferred "
-        "from the config.json file. "
-        "(default: %(default)s)",
+        help=HELP["model_type"] + ' (default: "%(default)s", choices: %(choices)s)',
     )
     parser.add_argument(
         "--conv-template",
         type=str,
         required=True,
         choices=list(CONV_TEMPLATES),
-        help='Conversation template. It depends on how the model is tuned. Use "LM" for vanilla '
-        "base model",
+        help=HELP["conv_template"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--max-sequence-length",
         type=int,
         default=None,
-        help="Option to override the maximum sequence length supported by the model. "
-        "An LLM is usually trained with a fixed maximum sequence length, which is usually "
-        "explicitly specified in model spec. By default, if this option is not set explicitly, "
-        "the maximum sequence length is determined by `max_sequence_length` or "
-        "`max_position_embeddings` in config.json, which can be inaccuate for some models.",
+        help=HELP["max_sequence_length"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--output",
         "-o",
         type=_parse_output,
         required=True,
-        help="The output directory for generated configurations, including `mlc-chat-config.json`, "
-        "and tokenizer configuration.",
+        help=HELP["output_gen_mlc_chat_config"] + " (required)",
     )
     parsed = parser.parse_args()
     model = detect_model_type(parsed.model_type, parsed.config)
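All three entrypoints now construct the parser from ..support.argparse rather than the standard library. That module is not among the hunks shown; a common reason for such a wrapper is friendlier error reporting, so a speculative sketch of the pattern might be:

import argparse
import sys


class ArgumentParser(argparse.ArgumentParser):
    """Drop-in replacement that prints full help on error.

    Speculative sketch; the real mlc_chat.support.argparse may differ.
    """

    def error(self, message: str):
        # Show the complete flag reference instead of the terse one-liner,
        # then exit with the conventional argparse status code.
        self.print_help(sys.stderr)
        self.exit(2, f"{self.prog}: error: {message}\n")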
1 change: 1 addition & 0 deletions python/mlc_chat/compiler/__init__.py
@@ -8,6 +8,7 @@
 from .flags_model_config_override import ModelConfigOverride
 from .flags_optimization import OptimizationFlags
 from .gen_mlc_chat_config import CONV_TEMPLATES, gen_config
+from .help import HELP
 from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping
 from .model import MODEL_PRESETS, MODELS, Model
 from .quantization import QUANTIZATION
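The new python/mlc_chat/compiler/help.py is one of the changed files not rendered on this page. Judging from the keys referenced above and the inline strings they replace, it plausibly looks like the following (speculative reconstruction, not the actual file):

# Speculative reconstruction of python/mlc_chat/compiler/help.py,
# built from the help strings the three CLIs previously inlined.
HELP = {
    "model": "Path to config.json file or to the directory that contains "
    "config.json, which is a HuggingFace standard that defines model architecture.",
    "quantization": "Quantization format.",
    "model_type": "Model architecture, for example, llama. If not set, it is "
    "inferred from the config.json file.",
    "device_compile": "The GPU device to compile the model to. If not set, it is "
    "inferred from locally available GPUs.",
    "device_quantize": "The device used to do quantization, for example `cuda:0`. "
    "Detected from the local environment if not specified.",
    "host": "The host CPU ISA to compile the model to. If not set, it is inferred "
    "from the local CPU.",
    # ... plus "opt", "prefix_symbols", "max_sequence_length", "source",
    # "source_format", "conv_template", "output_compile", "output_quantize",
    # and "output_gen_mlc_chat_config", consolidated the same way.
}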
3 changes: 3 additions & 0 deletions python/mlc_chat/compiler/gen_mlc_chat_config.py
@@ -15,12 +15,15 @@
 
 FOUND = green("Found")
 NOT_FOUND = red("Not found")
+VERSION = "0.1.0"
 
 
 @dataclasses.dataclass
 class MLCChatConfig:  # pylint: disable=too-many-instance-attributes
     """Arguments for `mlc_chat.compiler.gen_config`."""
 
+    version: str = VERSION
+
     model_type: str = None
     quantization: str = None
     model_config: Dict[str, Any] = None
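With the version field defaulting to VERSION, every generated mlc-chat-config.json presumably now carries a schema-version marker, letting downstream tooling detect configs produced by an older generator. A sketch of the effect (illustrative values taken from examples elsewhere in this diff; exact field order not guaranteed):

{
  "version": "0.1.0",
  "model_type": "llama",
  "quantization": "q4f16_1",
  ...
}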
(Diffs for the remaining 3 changed files did not render.)