[Chat] Support chat completion config override #2412

Merged · 1 commit · May 24, 2024
4 changes: 2 additions & 2 deletions cpp/json_ffi/json_ffi_engine.cc
@@ -91,8 +91,8 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
gen_cfg->logprobs = request.logprobs;
gen_cfg->top_logprobs = request.top_logprobs;
gen_cfg->logit_bias = request.logit_bias.value_or(default_gen_cfg->logit_bias);
gen_cfg->seed = request.seed.value_or(default_gen_cfg->seed);
gen_cfg->max_tokens = request.seed.value_or(default_gen_cfg->max_tokens);
gen_cfg->seed = request.seed.value_or(std::random_device{}());
gen_cfg->max_tokens = request.max_tokens.value_or(default_gen_cfg->max_tokens);
gen_cfg->stop_strs = std::move(stop_strs);
gen_cfg->stop_token_ids = conv_template_.stop_token_ids;
gen_cfg->debug_config = request.debug_config.value_or(DebugConfig());
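For context, here is a minimal sketch (Python, standard library only) of the kind of request body that AddRequest receives as request_json_str. The field names match the OpenAI-style schema parsed in this PR; the model id and message content are placeholders. With this change, an omitted "seed" falls back to a freshly drawn random seed rather than the engine default, and "max_tokens" now correctly falls back to the default generation config instead of being read from "seed".

import json

# Hypothetical chat completion request as passed to JSONFFIEngine::AddRequest.
# Only fields handled by this PR's parser are shown.
request = {
    "model": "demo-model",  # placeholder model id
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
    "top_p": 0.95,
    "max_tokens": 128,  # overrides default_gen_cfg->max_tokens
    # "seed" omitted: the engine now draws one via std::random_device
}
request_json_str = json.dumps(request)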
41 changes: 39 additions & 2 deletions cpp/json_ffi/openai_api_protocol.cc
@@ -295,29 +295,66 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
}
request.model = model_res.Unwrap();

// temperature
Result<std::optional<double>> temperature_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "temperature");
if (temperature_res.IsErr()) {
return TResult::Error(temperature_res.UnwrapErr());
}
request.temperature = temperature_res.Unwrap();
// top_p
Result<std::optional<double>> top_p_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "top_p");
if (top_p_res.IsErr()) {
return TResult::Error(top_p_res.UnwrapErr());
}
request.top_p = top_p_res.Unwrap();
// max_tokens
Result<std::optional<int64_t>> max_tokens_res =
json::LookupOptionalWithResultReturn<int64_t>(json_obj, "max_tokens");
if (max_tokens_res.IsErr()) {
return TResult::Error(max_tokens_res.UnwrapErr());
}
request.max_tokens = max_tokens_res.Unwrap();

// frequency_penalty
Result<std::optional<double>> frequency_penalty_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "frequency_penalty");
if (frequency_penalty_res.IsErr()) {
return TResult::Error(frequency_penalty_res.UnwrapErr());
}
request.frequency_penalty = frequency_penalty_res.Unwrap();

// presence_penalty
Result<std::optional<double>> presence_penalty_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "presence_penalty");
if (presence_penalty_res.IsErr()) {
return TResult::Error(presence_penalty_res.UnwrapErr());
}
request.presence_penalty = presence_penalty_res.Unwrap();
// seed
Result<std::optional<int64_t>> seed_res =
json::LookupOptionalWithResultReturn<int64_t>(json_obj, "seed");
if (seed_res.IsErr()) {
return TResult::Error(seed_res.UnwrapErr());
}
request.seed = seed_res.Unwrap();

// stop strings
Result<std::optional<picojson::array>> stop_strs_res =
json::LookupOptionalWithResultReturn<picojson::array>(json_obj, "stop");
if (stop_strs_res.IsErr()) {
return TResult::Error(stop_strs_res.UnwrapErr());
}
std::optional<picojson::array> stop_strs = stop_strs_res.Unwrap();
if (stop_strs.has_value()) {
std::vector<std::string> stop;
for (picojson::value stop_str_value : stop_strs.value()) {
if (!stop_str_value.is<std::string>()) {
return TResult::Error("One given value in field \"stop\" is not a string.");
}
stop.push_back(stop_str_value.get<std::string>());
}
request.stop = std::move(stop);
}

// tool_choice
Result<std::string> tool_choice_res =
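As a usage sketch (not part of this diff), the newly parsed fields map one-to-one onto keyword arguments of the JSON FFI engine's OpenAI-style chat.completions.create call used by the chat frontend below. The engine constructor arguments and the assumption that streamed deltas expose a string content field are illustrative; adjust them to the actual JSONFFIEngine setup.

from mlc_llm.json_ffi import JSONFFIEngine

engine = JSONFFIEngine("path/to/model")  # constructor args are illustrative

# Stream a completion with per-request overrides; these keyword arguments are
# exactly the fields ChatCompletionRequest::FromJSON now understands.
for response in engine.chat.completions.create(
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    stream=True,
    temperature=0.7,
    top_p=0.95,
    max_tokens=64,
    presence_penalty=0.1,
    seed=23,
    stop=["</s>"],
):
    for choice in response.choices:
        print(choice.delta.content or "", end="", flush=True)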
63 changes: 55 additions & 8 deletions python/mlc_llm/interface/chat.py
@@ -1,27 +1,66 @@
"""Python entrypoint of chat."""
from typing import List, Optional

import dataclasses
from typing import Dict, List, Optional, Union

from prompt_toolkit import prompt as get_prompt # pylint: disable=import-error
from prompt_toolkit.key_binding import KeyBindings # pylint: disable=import-error

from mlc_llm.json_ffi import JSONFFIEngine
from mlc_llm.support import argparse
from mlc_llm.support.config import ConfigOverrideBase


@dataclasses.dataclass
class ChatCompletionOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes
"""Flags for overriding chat completions."""

temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None

@staticmethod
def from_str(source: str) -> "ChatCompletionOverride":
"""Parse model config override values from a string."""
parser = argparse.ArgumentParser(description="chat completion override values")
parser.add_argument("--temperature", type=float, default=None)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--frequency_penalty", type=float, default=None)
parser.add_argument("--presence_penalty", type=float, default=None)
parser.add_argument("--max_tokens", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--stop", type=str, default=None)
results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
return ChatCompletionOverride(
temperature=results.temperature,
top_p=results.top_p,
frequency_penalty=results.frequency_penalty,
presence_penalty=results.presence_penalty,
max_tokens=results.max_tokens,
seed=results.seed,
stop=results.stop.split(",") if results.stop is not None else None,
)


class ChatState:
"""Helper class to manage chat state"""

history: List[dict]
history: List[Dict]
history_begin: int
# kwargs passed to completions
overrides: dict
overrides: ChatCompletionOverride
# we use JSON ffi engine to ensure broader coverage
engine: JSONFFIEngine

def __init__(self, engine):
self.engine = engine
self.history = []
self.history_window_begin = 0
self.overrides = {}
self.overrides = ChatCompletionOverride()

def process_system_prompts(self):
"""Process system prompts"""
@@ -45,7 +84,9 @@ def generate(self, prompt: str):
finish_reason_length = False
messages = self.history[self.history_window_begin :]
for response in self.engine.chat.completions.create(
messages=messages, stream=True, **self.overrides
messages=messages,
stream=True,
**dataclasses.asdict(self.overrides),
):
for choice in response.choices:
assert choice.delta.role == "assistant"
@@ -90,6 +131,9 @@ def _print_help_str():
/stats print out stats of last request (token/sec)
/metrics print out full engine metrics
/reset restart a fresh chat
/set [overrides] override settings in the generation config. For example,
`/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`
Note: Separate stop words in the `stop` option with commas (,).
Multi-line input: Use escape+enter to start a new line.
"""
print(help_str)
@@ -132,16 +176,19 @@ def chat(
key_bindings=kb,
multiline=True,
)
if prompt[:6] == "/stats":
if prompt[:4] == "/set":
overrides = ChatCompletionOverride.from_str(prompt.split()[1])
for key, value in dataclasses.asdict(overrides).items():
if value is not None:
setattr(chat_state.overrides, key, value)
elif prompt[:6] == "/stats":
print(chat_state.stats(), flush=True)
elif prompt[:8] == "/metrics":
print(chat_state.metrics(), flush=True)
elif prompt[:6] == "/reset":
chat_state.reset_chat()
elif prompt[:5] == "/exit":
break
# elif prompt[:6] == "/stats":
# print(cm.stats(), flush=True)
elif prompt[:5] == "/help":
_print_help_str()
else:
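A quick sketch of how the /set handling above composes overrides across turns: from_str parses the semicolon-separated flags, and only the fields that were explicitly provided replace previously set values. The import path assumes this module's location (python/mlc_llm/interface/chat.py).

import dataclasses

from mlc_llm.interface.chat import ChatCompletionOverride

current = ChatCompletionOverride()
current.temperature = 0.9
new = ChatCompletionOverride.from_str("top_p=0.8;seed=23;stop=User:,###")

# Mirror the /set branch: copy only the non-None fields onto the running state.
for key, value in dataclasses.asdict(new).items():
    if value is not None:
        setattr(current, key, value)

assert current.temperature == 0.9        # left untouched
assert current.top_p == 0.8              # overridden
assert current.stop == ["User:", "###"]  # stop words split on commas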