[Chat] Support chat completion config override (#2412)
This PR adds support for overriding chat completion arguments from the chat CLI.

The arguments currently supported are `top_p`, `temperature`,
`presence_penalty`, `frequency_penalty`, `max_tokens`, `seed`,
and `stop`.

It also adds the corresponding parsing support for these fields to the
ChatCompletionRequest handling in JSONFFIEngine.
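For context, a minimal sketch of a request body carrying the newly parsed override fields (the field names are the ones handled in this PR; the model id, message, and all values are purely illustrative):

import json

# Illustrative request body. Every override field below is optional; when a
# field is omitted, the engine falls back to its default generation config.
request = {
    "model": "my-model",  # hypothetical model id
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.5,
    "top_p": 0.8,
    "presence_penalty": 0.1,
    "frequency_penalty": 0.2,
    "max_tokens": 100,
    "seed": 23,
    "stop": ["str1", "str2"],
}
request_json_str = json.dumps(request)  # the JSON string handed to the engine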
MasterJH5574 authored May 24, 2024
1 parent 7eba612 commit 905620c
Showing 3 changed files with 96 additions and 12 deletions.
4 changes: 2 additions & 2 deletions cpp/json_ffi/json_ffi_engine.cc
@@ -91,8 +91,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
gen_cfg->logprobs = request.logprobs;
gen_cfg->top_logprobs = request.top_logprobs;
gen_cfg->logit_bias = request.logit_bias.value_or(default_gen_cfg->logit_bias);
- gen_cfg->seed = request.seed.value_or(default_gen_cfg->seed);
- gen_cfg->max_tokens = request.seed.value_or(default_gen_cfg->max_tokens);
+ gen_cfg->seed = request.seed.value_or(std::random_device{}());
+ gen_cfg->max_tokens = request.max_tokens.value_or(default_gen_cfg->max_tokens);
gen_cfg->stop_strs = std::move(stop_strs);
gen_cfg->stop_token_ids = conv_template_.stop_token_ids;
gen_cfg->debug_config = request.debug_config.value_or(DebugConfig());
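To summarize the change above: each per-request value falls back to the engine's default generation config when absent, except `seed`, which now falls back to a freshly drawn random value, and `max_tokens`, which previously read the `seed` field by mistake and now reads `max_tokens`. A rough Python sketch of the same fallback behavior (illustrative only; the authoritative code is the C++ diff above):

import random
from typing import Any, Dict

def resolve_overrides(request: Dict[str, Any], default_gen_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch of the override-or-default resolution in JSONFFIEngine::AddRequest."""
    resolved = dict(default_gen_cfg)
    # A request without an explicit seed gets a fresh random seed,
    # not the seed stored in the default config.
    resolved["seed"] = request.get("seed", random.getrandbits(31))
    # max_tokens now correctly falls back to the default config value.
    resolved["max_tokens"] = request.get("max_tokens", default_gen_cfg["max_tokens"])
    return resolved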
41 changes: 39 additions & 2 deletions cpp/json_ffi/openai_api_protocol.cc
@@ -295,29 +295,66 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
}
request.model = model_res.Unwrap();

// temperature
Result<std::optional<double>> temperature_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "temperature");
if (temperature_res.IsErr()) {
return TResult::Error(temperature_res.UnwrapErr());
}
request.temperature = temperature_res.Unwrap();
// top_p
Result<std::optional<double>> top_p_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "top_p");
if (top_p_res.IsErr()) {
return TResult::Error(top_p_res.UnwrapErr());
}
request.top_p = top_p_res.Unwrap();
// max_tokens
Result<std::optional<int64_t>> max_tokens_res =
json::LookupOptionalWithResultReturn<int64_t>(json_obj, "max_tokens");
if (max_tokens_res.IsErr()) {
return TResult::Error(max_tokens_res.UnwrapErr());
}
request.max_tokens = max_tokens_res.Unwrap();

// frequency_penalty
Result<std::optional<double>> frequency_penalty_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "frequency_penalty");
if (frequency_penalty_res.IsErr()) {
return TResult::Error(frequency_penalty_res.UnwrapErr());
}
request.frequency_penalty = frequency_penalty_res.Unwrap();

// presence_penalty
Result<std::optional<double>> presence_penalty_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "presence_penalty");
if (presence_penalty_res.IsErr()) {
return TResult::Error(presence_penalty_res.UnwrapErr());
}
request.presence_penalty = presence_penalty_res.Unwrap();
// seed
Result<std::optional<int64_t>> seed_res =
json::LookupOptionalWithResultReturn<int64_t>(json_obj, "seed");
if (seed_res.IsErr()) {
return TResult::Error(seed_res.UnwrapErr());
}
request.seed = seed_res.Unwrap();

// stop strings
Result<std::optional<picojson::array>> stop_strs_res =
json::LookupOptionalWithResultReturn<picojson::array>(json_obj, "stop");
if (stop_strs_res.IsErr()) {
return TResult::Error(stop_strs_res.UnwrapErr());
}
std::optional<picojson::array> stop_strs = stop_strs_res.Unwrap();
if (stop_strs.has_value()) {
std::vector<std::string> stop;
for (picojson::value stop_str_value : stop_strs.value()) {
if (!stop_str_value.is<std::string>()) {
return TResult::Error("One given value in field \"stop\" is not a string.");
}
stop.push_back(stop_str_value.get<std::string>());
}
request.stop = std::move(stop);
}

// tool_choice
Result<std::string> tool_choice_res =
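Each of the new fields above follows the same pattern: look up an optional value, return an error if the type is wrong, otherwise store it on the request. The `stop` field additionally requires every array element to be a string. A hedged Python sketch of that check (for illustration only; the authoritative logic is the C++ above):

from typing import Any, Dict, List, Optional

def parse_stop(json_obj: Dict[str, Any]) -> Optional[List[str]]:
    """Mirror of the `stop` parsing: an optional JSON array whose elements must all be strings."""
    stop_strs = json_obj.get("stop")
    if stop_strs is None:
        return None
    if not isinstance(stop_strs, list):
        raise ValueError('Field "stop" is not an array.')
    for value in stop_strs:
        if not isinstance(value, str):
            raise ValueError('One given value in field "stop" is not a string.')
    return list(stop_strs)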
63 changes: 55 additions & 8 deletions python/mlc_llm/interface/chat.py
@@ -1,27 +1,66 @@
"""Python entrypoint of chat."""
- from typing import List, Optional

+ import dataclasses
+ from typing import Dict, List, Optional, Union

from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error
from prompt_toolkit.key_binding import KeyBindings # pylint: disable=import-error

from mlc_llm.json_ffi import JSONFFIEngine
from mlc_llm.support import argparse
from mlc_llm.support.config import ConfigOverrideBase


@dataclasses.dataclass
class ChatCompletionOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes
"""Flags for overriding chat completions."""

temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None

@staticmethod
def from_str(source: str) -> "ChatCompletionOverride":
"""Parse model config override values from a string."""
parser = argparse.ArgumentParser(description="chat completion override values")
parser.add_argument("--temperature", type=float, default=None)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--frequency_penalty", type=float, default=None)
parser.add_argument("--presence_penalty", type=float, default=None)
parser.add_argument("--max_tokens", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--stop", type=str, default=None)
results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return ChatCompletionOverride(
temperature=results.temperature,
top_p=results.top_p,
frequency_penalty=results.frequency_penalty,
presence_penalty=results.presence_penalty,
max_tokens=results.max_tokens,
seed=results.seed,
stop=results.stop.split(",") if results.stop is not None else None,
)


class ChatState:
"""Helper class to manage chat state"""

- history: List[dict]
+ history: List[Dict]
history_begin: int
# kwargs passed to completions
- overrides: dict
+ overrides: ChatCompletionOverride
# we use JSON ffi engine to ensure broader coverage
engine: JSONFFIEngine

def __init__(self, engine):
self.engine = engine
self.history = []
self.history_window_begin = 0
- self.overrides = {}
+ self.overrides = ChatCompletionOverride()

def process_system_prompts(self):
"""Process system prompts"""
@@ -45,7 +84,9 @@ def generate(self, prompt: str):
finish_reason_length = False
messages = self.history[self.history_window_begin :]
for response in self.engine.chat.completions.create(
- messages=messages, stream=True, **self.overrides
+ messages=messages,
+ stream=True,
+ **dataclasses.asdict(self.overrides),
):
for choice in response.choices:
assert choice.delta.role == "assistant"
@@ -90,6 +131,9 @@ def _print_help_str():
/stats print out stats of last request (token/sec)
/metrics print out full engine metrics
/reset restart a fresh chat
/set [overrides] override settings in the generation config. For example,
`/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
Note: Separate stop words in the `stop` option with commas (,).
Multi-line input: Use escape+enter to start a new line.
"""
print(help_str)
@@ -132,16 +176,19 @@ def chat(
key_bindings=kb,
multiline=True,
)
if prompt[:6] == "/stats":
if prompt[:4] == "/set":
overrides = ChatCompletionOverride.from_str(prompt.split()[1])
for key, value in dataclasses.asdict(overrides).items():
if value is not None:
setattr(chat_state.overrides, key, value)
elif prompt[:6] == "/stats":
print(chat_state.stats(), flush=True)
elif prompt[:8] == "/metrics":
print(chat_state.metrics(), flush=True)
elif prompt[:6] == "/reset":
chat_state.reset_chat()
elif prompt[:5] == "/exit":
break
- # elif prompt[:6] == "/stats":
- # print(cm.stats(), flush=True)
elif prompt[:5] == "/help":
_print_help_str()
else:
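For reference, a small usage sketch of the override handling added above, assuming the `ChatCompletionOverride` class from python/mlc_llm/interface/chat.py is importable at that module path; the argument string matches the format the `/set` command accepts, and the values are illustrative:

import dataclasses

from mlc_llm.interface.chat import ChatCompletionOverride  # assumed module path

# `/set temperature=0.5;seed=23;stop=str1,str2` hands the part after the space
# to ChatCompletionOverride.from_str.
overrides = ChatCompletionOverride.from_str("temperature=0.5;seed=23;stop=str1,str2")
assert overrides.temperature == 0.5
assert overrides.stop == ["str1", "str2"]

# Only the fields the user actually set are copied onto the chat state's
# overrides, so repeated /set commands accumulate rather than reset.
chat_overrides = ChatCompletionOverride()
for key, value in dataclasses.asdict(overrides).items():
    if value is not None:
        setattr(chat_overrides, key, value)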
