[CLI] Migrate CLI to use the new Engine #2375

Merged: 2 commits, May 22, 2024
6 changes: 1 addition & 5 deletions python/mlc_llm/__main__.py
@@ -14,7 +14,7 @@ def main():
parser.add_argument(
"subcommand",
type=str,
choices=["compile", "convert_weight", "gen_config", "chat", "serve", "bench", "package"],
choices=["compile", "convert_weight", "gen_config", "chat", "serve", "package"],
help="Subcommand to to run. (choices: %(choices)s)",
)
parsed = parser.parse_args(sys.argv[1:2])
@@ -38,10 +38,6 @@ def main():
elif parsed.subcommand == "serve":
from mlc_llm.cli import serve as cli

cli.main(sys.argv[2:])
elif parsed.subcommand == "bench":
from mlc_llm.cli import bench as cli

cli.main(sys.argv[2:])
elif parsed.subcommand == "package":
from mlc_llm.cli import package as cli
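Note: `__main__.py` keeps its two-stage dispatch, so removing `bench` only deletes one branch. argparse parses `sys.argv[1:2]` (the subcommand name alone), and the remaining arguments are forwarded untouched to the selected sub-CLI. A minimal sketch of that pattern, using a hypothetical `demo` subcommand that is not part of this PR:

```python
# Minimal sketch of the two-stage dispatch used in __main__.py.
# The "demo" subcommand and run_demo handler are hypothetical,
# for illustration only; they are not part of this PR.
import argparse
import sys


def run_demo(argv):
    print(f"demo called with {argv}")


def main():
    parser = argparse.ArgumentParser("MLC LLM Command Line Interface.")
    parser.add_argument("subcommand", type=str, choices=["demo"])
    # Parse only argv[1:2]: the subcommand itself.
    parsed = parser.parse_args(sys.argv[1:2])
    if parsed.subcommand == "demo":
        # Forward everything after the subcommand to the sub-CLI.
        run_demo(sys.argv[2:])


if __name__ == "__main__":
    main()
```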
63 changes: 0 additions & 63 deletions python/mlc_llm/cli/bench.py

This file was deleted.

16 changes: 1 addition & 15 deletions python/mlc_llm/cli/chat.py
@@ -1,7 +1,7 @@
"""Command line entrypoint of chat."""

from mlc_llm.help import HELP
from mlc_llm.interface.chat import ChatConfigOverride, chat
from mlc_llm.interface.chat import chat
from mlc_llm.support.argparse import ArgumentParser


@@ -14,24 +14,12 @@ def main(argv):
type=str,
help=HELP["model"] + " (required)",
)
parser.add_argument(
"--opt",
type=str,
default="O2",
help=HELP["opt"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--device",
type=str,
default="auto",
help=HELP["device_deploy"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--overrides",
type=ChatConfigOverride.from_str,
default="",
help=HELP["chatconfig_overrides"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--model-lib",
type=str,
@@ -42,7 +30,5 @@ def main(argv):
chat(
model=parsed.model,
device=parsed.device,
opt=parsed.opt,
overrides=parsed.overrides,
model_lib=parsed.model_lib,
)
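Note: with `--opt` and `--overrides` removed, the chat CLI forwards only the model, the device, and an optional model library to the interface layer. A hedged sketch of the equivalent programmatic call (the model path below is a placeholder for illustration, not something this PR ships):

```python
# Equivalent of what the trimmed CLI now forwards to the interface layer.
# The model path is a placeholder, not part of this PR.
from mlc_llm.interface.chat import chat

chat(
    model="HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",  # placeholder model
    device="auto",   # same default the CLI uses
    model_lib=None,  # optional precompiled model library
)
```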
29 changes: 0 additions & 29 deletions python/mlc_llm/interface/bench.py

This file was deleted.

183 changes: 72 additions & 111 deletions python/mlc_llm/interface/chat.py
@@ -1,103 +1,76 @@
"""Python entrypoint of chat."""

import dataclasses
from typing import List, Optional, Union
from typing import List, Optional

from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error
from prompt_toolkit.key_binding import KeyBindings # pylint: disable=import-error

from mlc_llm.callback import StreamToStdout
from mlc_llm.chat_module import ChatConfig, ChatModule, GenerationConfig
from mlc_llm.support import argparse
from mlc_llm.support.config import ConfigOverrideBase


@dataclasses.dataclass
class ChatConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes
"""Flags for overriding chat config."""

conv_template: Optional[str] = None
context_window_size: Optional[int] = None
sliding_window_size: Optional[int] = None
prefill_chunk_size: Optional[int] = None
attention_sink_size: Optional[int] = None
max_batch_size: Optional[int] = None
tensor_parallel_shards: Optional[int] = None

@staticmethod
def from_str(source: str) -> "ChatConfigOverride":
"""Parse model config override values from a string."""
parser = argparse.ArgumentParser(description="chat config override values")
parser.add_argument("--conv_template", type=str, default=None)
parser.add_argument("--tensor_parallel_shards", type=int, default=None)
parser.add_argument("--context_window_size", type=int, default=None)
parser.add_argument("--sliding_window_size", type=int, default=None)
parser.add_argument("--prefill_chunk_size", type=int, default=None)
parser.add_argument("--attention_sink_size", type=int, default=None)
parser.add_argument("--max_batch_size", type=int, default=None)

results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return ChatConfigOverride(
conv_template=results.conv_template,
tensor_parallel_shards=results.tensor_parallel_shards,
context_window_size=results.context_window_size,
sliding_window_size=results.sliding_window_size,
prefill_chunk_size=results.prefill_chunk_size,
attention_sink_size=results.attention_sink_size,
max_batch_size=results.max_batch_size,
)


@dataclasses.dataclass
class GenerationConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes
"""Flags for overriding generation config."""

temperature: Optional[float] = None
repetition_penalty: Optional[float] = None
top_p: Optional[float] = None
mean_gen_len: Optional[int] = None
max_gen_len: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
n: Optional[int] = None # pylint: disable=invalid-name
stop: Optional[Union[str, List[str]]] = None

@staticmethod
def from_str(source: str) -> "GenerationConfigOverride":
"""Parse model config override values from a string."""
parser = argparse.ArgumentParser(description="generation config override values")
parser.add_argument("--temperature", type=float, default=None)
parser.add_argument("--repetition_penalty", type=float, default=None)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--mean_gen_len", type=int, default=None)
parser.add_argument("--max_gen_len", type=int, default=None)
parser.add_argument("--presence_penalty", type=float, default=None)
parser.add_argument("--frequency_penalty", type=float, default=None)
parser.add_argument("--n", type=int, default=None)
parser.add_argument("--stop", type=str, default=None)
results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return GenerationConfigOverride(
temperature=results.temperature,
repetition_penalty=results.repetition_penalty,
top_p=results.top_p,
mean_gen_len=results.mean_gen_len,
max_gen_len=results.max_gen_len,
presence_penalty=results.presence_penalty,
frequency_penalty=results.frequency_penalty,
n=results.n,
stop=results.stop.split(",") if results.stop is not None else None,
)


from mlc_llm.json_ffi import JSONFFIEngine


class ChatState:
"""Helper class to manage chat state"""

history: List[dict]
history_window_begin: int
# we use JSON ffi engine to ensure broader coverage
engine: JSONFFIEngine

def __init__(self, engine):
self.engine = engine
self.history = []
self.history_window_begin = 0

def process_system_prompts(self):
"""Process system prompts"""
# TODO(mlc-team): possibly leverage debug option
# pass a simple prompt to warm up
for _ in self.engine.chat.completions.create(
messages=[{"role": "user", "content": "hello"}], max_tokens=1, stream=True
):
pass

def slide_history(self):
"""Slide history to fit into context window"""
history_window_size = len(self.history) - self.history_window_begin
assert history_window_size % 2 == 0
self.history_window_begin += (history_window_size // 4) * 2

def generate(self, prompt: str):
"""Run one generatiohn with the prompt"""
self.history.append({"role": "user", "content": prompt})
output_text = ""
finish_reason_length = False
messages = self.history[self.history_window_begin :]
for response in self.engine.chat.completions.create(messages=messages, stream=True):
for choice in response.choices:
assert choice.delta.role == "assistant"
if isinstance(choice.delta.content, str):
output_text += choice.delta.content
print(choice.delta.content, end="", flush=True)
if choice.finish_reason == "length":
finish_reason_length = True
if finish_reason_length:
print(" [output truncated due to context length limit...]")
# print additional \n when generation ends
print()
# record the history
self.history.append({"role": "assistant", "content": output_text})
if finish_reason_length:
self.slide_history()

def reset_chat(self):
"""Reset the chat history"""
self.history = []
self.history_window_begin = 0


# TODO(mlc-team): add back support for stats
def _print_help_str():
help_str = """You can use the following special commands:
/help print the special commands
/exit quit the cli
/stats print out the latest stats (token/sec)
/reset restart a fresh chat
/set [overrides] override settings in the generation config. For example,
`/set temperature=0.5;max_gen_len=100;stop=end,stop`
Note: Separate stop words in the `stop` option with commas (,).
Multi-line input: Use escape+enter to start a new line.
"""
print(help_str)
@@ -120,45 +93,33 @@ def _(event):
def chat(
model: str,
device: str,
opt: str,
overrides: ChatConfigOverride,
model_lib: Optional[str],
):
"""chat with a model."""
# Set up chat config and generate config
config = ChatConfig(opt=opt)
generate_config = GenerationConfig()
# Apply overrides
config = overrides.apply(config)

# Set up ChatModule
cm = ChatModule(model, device, chat_config=config, model_lib=model_lib)
engine = JSONFFIEngine(model, device, model_lib=model_lib, mode="interactive")
_print_help_str()
cm._process_system_prompts() # pylint: disable=protected-access

chat_state = ChatState(engine)
chat_state.process_system_prompts()

# Multi-line input support: use escape+enter to start a new line
kb = _set_up_key_bindings()

while True:
prompt = get_prompt(
f"{cm._get_role_0()}: ", # pylint: disable=protected-access
">>> ", # pylint: disable=protected-access
key_bindings=kb,
multiline=True,
)
if prompt[:6] == "/reset":
cm.reset_chat()
chat_state.reset_chat()
elif prompt[:5] == "/exit":
break
elif prompt[:6] == "/stats":
print(cm.stats(), flush=True)
elif prompt[:4] == "/set":
gen_config_overrides = GenerationConfigOverride.from_str(prompt.split()[1])
generate_config = gen_config_overrides.apply(generate_config)
# elif prompt[:6] == "/stats":
# print(cm.stats(), flush=True)
elif prompt[:5] == "/help":
_print_help_str()
else:
print(f"{cm._get_role_1()}: ") # pylint: disable=protected-access
cm.generate(
prompt,
progress_callback=StreamToStdout(callback_interval=2),
generation_config=generate_config,
)
chat_state.generate(prompt)
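Note: when a reply is cut off by the context limit, `slide_history` advances the window start by `(window // 4) * 2`, which drops roughly the oldest half of the visible user/assistant pairs, rounded down to whole pairs. A small standalone check of that arithmetic (plain Python, not part of the PR):

```python
# Standalone check of the slide_history arithmetic used by ChatState.
# With 8 visible messages (4 user/assistant pairs), the window start
# advances by (8 // 4) * 2 == 4, i.e. the oldest 2 pairs are dropped.
history = [{"role": "user", "content": "q"}, {"role": "assistant", "content": "a"}] * 4
history_window_begin = 0

window = len(history) - history_window_begin
assert window % 2 == 0          # history always grows in user/assistant pairs
history_window_begin += (window // 4) * 2

assert history_window_begin == 4
assert len(history[history_window_begin:]) == 4  # 2 pairs remain visible
```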
4 changes: 2 additions & 2 deletions python/mlc_llm/interface/compile.py
@@ -147,14 +147,14 @@ def _find_kv_cache_bytes(model: nn.Module, model_config) -> int:
if (
args.quantization.kind == "ft-quant"
and hasattr(model_config, "tensor_parallel_shards")
and model_config.tensor_parallel_shards > 1
and model_config.tensor_parallel_shards > 1 # type: ignore
):
raise NotImplementedError
if (
hasattr(args.quantization, "linear_weight_layout")
and args.quantization.linear_weight_layout == "KN"
and hasattr(model_config, "tensor_parallel_shards")
and model_config.tensor_parallel_shards > 1
and model_config.tensor_parallel_shards > 1 # type: ignore
):
raise NotImplementedError(
"KN layout (q3f16_0 and q4f16_0) is not supported for tensor parallelism"
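Note: the two guards in `compile.py` reject quantization schemes that cannot be sharded for tensor parallelism; the added `# type: ignore` comments only silence a static type-checker warning on the attribute access, the runtime logic is unchanged. A hedged standalone sketch of the same checks, using simplified stand-in config classes rather than MLC's real ones:

```python
# Simplified stand-ins for the quantization and model-config objects;
# only the fields the guards above inspect are modeled here, so this is
# an illustration of the check, not MLC's actual classes.
from dataclasses import dataclass


@dataclass
class FakeQuantize:
    kind: str
    linear_weight_layout: str = "NK"


@dataclass
class FakeModelConfig:
    tensor_parallel_shards: int = 1


def check_tensor_parallel_compat(quantization: FakeQuantize, model_config: FakeModelConfig) -> None:
    """Mirror the two guards: reject ft-quant and KN-layout schemes under tensor parallelism."""
    shards = getattr(model_config, "tensor_parallel_shards", 1)
    if quantization.kind == "ft-quant" and shards > 1:
        raise NotImplementedError("ft-quant does not support tensor parallelism")
    if getattr(quantization, "linear_weight_layout", None) == "KN" and shards > 1:
        raise NotImplementedError(
            "KN layout (q3f16_0 and q4f16_0) is not supported for tensor parallelism"
        )


# Passes: group quantization with 2-way tensor parallelism is allowed.
check_tensor_parallel_compat(FakeQuantize(kind="group-quant"), FakeModelConfig(tensor_parallel_shards=2))
```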