[JSON-FFI] Enable n generation and pass in json schema #2481

Merged: 1 commit, May 31, 2024
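Before the diff, a hedged sketch of the kind of request this change enables: the payload exercises both new fields, "n" and "response_format". The message content, schema string, and request id are placeholder values, and "engine" stands for a constructed JSON-FFI engine whose AddRequest signature appears in the first file below.

// Illustrative request only: "n" asks for two parallel generations, and
// "response_format" carries a JSON schema as a string under the "schema"
// key, which is the key the parser in openai_api_protocol.cc looks up.
std::string request_json = R"({
  "messages": [{"role": "user", "content": "List three fruits as JSON."}],
  "n": 2,
  "response_format": {
    "type": "json_object",
    "schema": "{\"type\": \"object\", \"properties\": {\"fruits\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}}}}"
  }
})";
engine->AddRequest(request_json, /*request_id=*/"req-0");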
cpp/json_ffi/json_ffi_engine.cc (68 additions, 57 deletions)

@@ -111,6 +111,7 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
gen_cfg->max_tokens = request.max_tokens.value_or(default_gen_cfg->max_tokens);
gen_cfg->stop_strs = std::move(stop_strs);
gen_cfg->stop_token_ids = conv_template_.stop_token_ids;
gen_cfg->response_format = request.response_format.value_or(ResponseFormat());
gen_cfg->debug_config = request.debug_config.value_or(DebugConfig());

Result<GenerationConfig> res_gen_config = GenerationConfig::Validate(GenerationConfig(gen_cfg));
@@ -120,12 +121,26 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
}

Request engine_request(request_id, inputs, res_gen_config.Unwrap());

// setup request state
RequestState rstate;
rstate.model = request.model.value_or("");
rstate.streamer.reserve(gen_cfg->n);
for (int i = 0; i < gen_cfg->n; ++i) {
rstate.streamer.push_back(TextStreamer(tokenizer_));
}
request_map_[request_id] = std::move(rstate);

this->engine_->AddRequest(engine_request);
return true;
}

bool JSONFFIEngine::Abort(std::string request_id) {
this->engine_->AbortRequest(request_id);
auto it = request_map_.find(request_id);
if (it != request_map_.end()) {
request_map_.erase(it);
}
return true;
}
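For reference, a minimal sketch of the per-request state that AddRequest now sets up, assuming a request with n = 2; RequestState, tokenizer_, and request_map_ are the members introduced in cpp/json_ffi/json_ffi_engine.h further down, and the model name is a placeholder.

// Sketch only (assumes gen_cfg->n == 2): one TextStreamer per generated
// sequence, so partial detokenization state never mixes across choices.
RequestState rstate;
rstate.model = "demo-model";  // placeholder; the real value comes from the request
rstate.streamer.reserve(2);
for (int i = 0; i < 2; ++i) {
  rstate.streamer.push_back(TextStreamer(tokenizer_));  // streamer for choice i
}
request_map_["req-0"] = std::move(rstate);

Erasing the map entry in Abort also means any delta outputs that arrive after an abort are dropped by the early continue in GetResponseFromStreamOutput below.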

@@ -187,10 +202,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
this->conv_template_ = conv_template.Unwrap();
this->model_config_ = ModelConfig::FromJSON(
json::Lookup<picojson::object>(model_config_json_unwrapped, "model_config"));

// Create streamer.
// Todo(mlc-team): Create one streamer for each request, instead of a global one.
this->streamer_ = TextStreamer(Tokenizer::FromPath(engine_config->model));
this->tokenizer_ = Tokenizer::FromPath(engine_config->model);
}

void Unload() { this->engine_->Unload(); }
@@ -202,23 +214,20 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); }

String GetResponseFromStreamOutput(Array<RequestStreamOutput> delta_outputs) {
std::unordered_map<std::string, std::vector<ChatCompletionStreamResponseChoice>> response_map;
std::vector<picojson::value> request_final_usage_messages;
std::string model = "json_ffi";

picojson::array json_response_arr;
for (const auto& delta_output : delta_outputs) {
std::string request_id = delta_output->request_id;
if (response_map.find(request_id) == response_map.end()) {
response_map[request_id] = std::vector<ChatCompletionStreamResponseChoice>();
}
auto request_state_it = request_map_.find(request_id);
if (request_state_it == request_map_.end()) continue;
RequestState& rstate = request_state_it->second;

// build the final usage messages
// invariant: we can always let other messages come first and the final
// usage message last, as final usage is always the last message
if (delta_output->request_final_usage_json_str.defined()) {
ChatCompletionStreamResponse response;
response.id = request_id;
response.model = model;
response.model = rstate.model;
response.system_fingerprint = "";
std::string usage_json_str = delta_output->request_final_usage_json_str.value();
picojson::value usage_json;
@@ -228,59 +237,61 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
} else {
response.usage = usage_json;
}
request_final_usage_messages.push_back(picojson::value(response.AsJSON()));
json_response_arr.push_back(picojson::value(response.AsJSON()));
request_map_.erase(request_state_it);
continue;
}
ICHECK_NE(delta_output->group_finish_reason.size(), 0);
ChatCompletionStreamResponseChoice choice;

if (delta_output->group_finish_reason.size() != 1) {
// Only support n = 1 in ChatCompletionStreamResponse for now
this->err_ += "Group finish reason should have exactly one element";
}
Optional<String> finish_reason = delta_output->group_finish_reason[0];
if (finish_reason.defined()) {
if (finish_reason.value() == "stop") {
choice.finish_reason = FinishReason::stop;
} else if (finish_reason.value() == "length") {
choice.finish_reason = FinishReason::length;
} else if (finish_reason.value() == "tool_calls") {
choice.finish_reason = FinishReason::tool_calls;
} else if (finish_reason.value() == "error") {
choice.finish_reason = FinishReason::error;
}
} else {
choice.finish_reason = std::nullopt;
}

choice.index = response_map[request_id].size();

ChatCompletionMessage delta;
// Size of delta_output->group_delta_token_ids Array should be 1
IntTuple delta_token_ids = delta_output->group_delta_token_ids[0];
std::vector<int32_t> delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end());
delta.content = this->streamer_->Put(delta_token_ids_vec);
delta.role = "assistant";
ICHECK_EQ(delta_output->group_delta_token_ids.size(),
delta_output->group_finish_reason.size());
ICHECK_EQ(delta_output->group_delta_token_ids.size(), rstate.streamer.size());

choice.delta = delta;

response_map[request_id].push_back(choice);
}

picojson::array response_arr;
for (const auto& [request_id, choices] : response_map) {
if (choices.size() == 0) continue;
ChatCompletionStreamResponse response;
response.id = request_id;
response.choices = choices;
response.model = "json_ffi"; // TODO: Return model name from engine (or from args)
response.model = rstate.model;
response.system_fingerprint = "";
response_arr.push_back(picojson::value(response.AsJSON()));
}
for (auto&& item : request_final_usage_messages) {
response_arr.emplace_back(std::move(item));

for (size_t i = 0; i < delta_output->group_finish_reason.size(); ++i) {
// choice
ChatCompletionStreamResponseChoice choice;
Optional<String> finish_reason = delta_output->group_finish_reason[i];
if (finish_reason.defined()) {
if (finish_reason.value() == "stop") {
choice.finish_reason = FinishReason::stop;
} else if (finish_reason.value() == "length") {
choice.finish_reason = FinishReason::length;
} else if (finish_reason.value() == "tool_calls") {
choice.finish_reason = FinishReason::tool_calls;
} else if (finish_reason.value() == "error") {
choice.finish_reason = FinishReason::error;
}
} else {
choice.finish_reason = std::nullopt;
}
choice.index = static_cast<int>(i);
ChatCompletionMessage delta;
// group_delta_token_ids holds one IntTuple per generated sequence; take the i-th
const IntTuple& delta_token_ids = delta_output->group_delta_token_ids[i];
std::vector<int32_t> delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end());
std::string content = rstate.streamer[i]->Put(delta_token_ids_vec);
if (finish_reason.defined()) {
content += rstate.streamer[i]->Finish();
}
if (!content.empty()) {
delta.content = content;
}
delta.role = "assistant";
choice.delta = delta;
if (!choice.delta.content.IsNull() || choice.finish_reason.has_value()) {
response.choices.push_back(choice);
}
}
// if it is not the usage block, choices cannot be empty
if (!response.choices.empty()) {
json_response_arr.push_back(picojson::value(response.AsJSON()));
}
}
return picojson::value(response_arr).serialize();
return picojson::value(json_response_arr).serialize();
}
};
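
To make the wire format concrete, a hedged sketch of consuming the serialized array this method returns for an n = 2 request; the payload is illustrative only, with placeholder ids, model name, and content.

#include <picojson.h>
#include <iostream>
#include <string>

int main() {
  // Placeholder payload in the shape GetResponseFromStreamOutput serializes:
  // one response per request, one choice per live generation stream.
  std::string payload = R"([
    {"id": "req-0", "model": "demo-model", "system_fingerprint": "",
     "choices": [
       {"index": 0, "delta": {"role": "assistant", "content": "{\"fruits\": "}},
       {"index": 1, "delta": {"role": "assistant", "content": "[\"apple\""},
        "finish_reason": "length"}
     ]}
  ])";
  picojson::value v;
  std::string err = picojson::parse(v, payload);
  if (!err.empty()) return 1;
  for (const auto& resp : v.get<picojson::array>()) {
    for (const auto& c : resp.get("choices").get<picojson::array>()) {
      // Route each delta by its "index" so the n parallel streams stay separate.
      std::cout << c.get("index").get<double>() << ": "
                << c.get("delta").get("content").get<std::string>() << "\n";
    }
  }
  return 0;
}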

cpp/json_ffi/json_ffi_engine.h (16 additions, 1 deletion)

@@ -44,14 +44,29 @@ class JSONFFIEngine {
void ExitBackgroundLoop();

protected:
/*! \brief local request state entry, one per reply stream. */
struct RequestState {
/*! \brief model to fill in reply. */
std::string model;
/*! \brief text streamer for each stream */
std::vector<TextStreamer> streamer;
};

std::unique_ptr<ThreadedEngine> engine_;
std::string err_;
PackedFunc request_stream_callback_;
TextStreamer streamer_; // TODO: Support "n", and support different streamers for each request
// tokenizer
Tokenizer tokenizer_;
// conversation template
Conversation conv_template_;
// generation config
GenerationConfig default_generation_config_;
// model config
ModelConfig model_config_;
// local device
DLDevice device_;
// request state map
std::unordered_map<String, RequestState> request_map_;
};

} // namespace json_ffi
cpp/json_ffi/openai_api_protocol.cc (18 additions, 0 deletions)

@@ -316,6 +316,12 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
return TResult::Error(max_tokens_res.UnwrapErr());
}
request.max_tokens = max_tokens_res.Unwrap();
// n
Result<int64_t> n_res = json::LookupOrDefaultWithResultReturn<int64_t>(json_obj, "n", 1);
if (n_res.IsErr()) {
return TResult::Error(n_res.UnwrapErr());
}
request.n = n_res.Unwrap();
// frequency_penalty
Result<std::optional<double>> frequency_penalty_res =
json::LookupOptionalWithResultReturn<double>(json_obj, "frequency_penalty");
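
A short note on the lookup pattern above: an absent "n" falls back to the OpenAI-style default of 1, while a present-but-mistyped value surfaces as an error instead of being silently ignored. A hedged behavior sketch, with helper names taken from the diff and the exact Result semantics inferred from the call sites:

// Assumes PICOJSON_USE_INT64 is defined, as this codebase's JSON helpers expect.
picojson::object json_obj;  // "n" absent
Result<int64_t> n_res =
    json::LookupOrDefaultWithResultReturn<int64_t>(json_obj, "n", 1);
// -> n_res.Unwrap() == 1: a single generation by default.

json_obj["n"] = picojson::value(static_cast<int64_t>(4));
n_res = json::LookupOrDefaultWithResultReturn<int64_t>(json_obj, "n", 1);
// -> n_res.Unwrap() == 4; a non-integer "n" would instead make n_res.IsErr()
//    true, which FromJSON propagates via TResult::Error.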
@@ -387,6 +393,18 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
request.tools = tools;
}

// response format
std::optional<picojson::object> response_format_obj =
json::LookupOptional<picojson::object>(json_obj, "response_format");
if (response_format_obj.has_value()) {
Result<ResponseFormat> response_format_res =
ResponseFormat::FromJSON(response_format_obj.value());
if (response_format_res.IsErr()) {
return TResult::Error(response_format_res.UnwrapErr());
}
request.response_format = response_format_res.Unwrap();
}

// debug_config
Result<std::optional<picojson::object>> debug_config_opt_res =
json::LookupOptionalWithResultReturn<picojson::object>(json_obj, "debug_config");
cpp/json_ffi/openai_api_protocol.h (6 additions, 11 deletions)

@@ -21,10 +21,13 @@ namespace mlc {
namespace llm {
namespace json_ffi {

using serve::DebugConfig;
using serve::ResponseFormat;

enum class Type { text, json_object, function };
enum class FinishReason { stop, length, tool_calls, error };

inline std::string generate_uuid_string(size_t length) {
inline std::string GenerateUUID(size_t length) {
auto randchar = []() -> char {
const char charset[] =
"0123456789"
@@ -71,7 +74,7 @@ class ChatFunctionCall {

class ChatToolCall {
public:
std::string id = "call_" + generate_uuid_string(8);
std::string id = "call_" + GenerateUUID(8);
Type type = Type::function;
ChatFunctionCall function;

@@ -122,14 +125,6 @@ class ChatCompletionMessage {
picojson::object AsJSON() const;
};

class RequestResponseFormat {
public:
Type type = Type::text;
std::optional<std::string> json_schema = std::nullopt;
};

using serve::DebugConfig;

class ChatCompletionRequest {
public:
std::vector<ChatCompletionMessage> messages;
@@ -150,7 +145,7 @@
std::optional<std::string> tool_choice = std::nullopt;
std::optional<std::string> user = std::nullopt;
bool ignore_eos = false;
// RequestResponseFormat response_format; //TODO: implement this
std::optional<ResponseFormat> response_format = std::nullopt;
std::optional<DebugConfig> debug_config = std::nullopt;

/*! \brief Parse and create a ChatCompletionRequest instance from the given JSON string. */
cpp/serve/config.cc (33 additions, 8 deletions)

@@ -34,6 +34,34 @@ uint64_t TotalDetectGlobalMemory(DLDevice device) {
return gpu_size_bytes;
}

/****************** ResponseFormat ******************/

Result<ResponseFormat> ResponseFormat::FromJSON(const picojson::object& config) {
using TResult = Result<ResponseFormat>;
ResponseFormat res;
res.type = json::LookupOrDefault<std::string>(config, "type", "text");

std::optional<std::string> schema = json::LookupOptional<std::string>(config, "schema");
if (schema.has_value()) {
res.schema = schema.value();
}

if (res.type != "text" && res.type != "function" && res.type != "json_object") {
return TResult::Error("Uknonwn response_format type " + res.type);
}

return TResult::Ok(res);
}

picojson::object ResponseFormat::AsJSON() const {
picojson::object config;
config["type"] = picojson::value(type);
if (schema.defined()) {
config["schema"] = picojson::value(schema.value().operator std::string());
}
return config;
}
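
A hedged round-trip sketch of the two helpers above; the schema string is a placeholder:

picojson::object obj;
obj["type"] = picojson::value(std::string("json_object"));
obj["schema"] = picojson::value(std::string("{\"type\": \"object\"}"));
Result<ResponseFormat> parsed = ResponseFormat::FromJSON(obj);
if (!parsed.IsErr()) {
  picojson::object round_trip = parsed.Unwrap().AsJSON();
  // round_trip["type"] holds "json_object" and round_trip["schema"]
  // carries the schema string through unchanged.
}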

/****************** DebugConfig ******************/

Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
@@ -178,15 +206,12 @@ Result<GenerationConfig> GenerationConfig::FromJSON(const picojson::object& conf
std::optional<picojson::object> response_format_obj =
json::LookupOptional<picojson::object>(config, "response_format");
if (response_format_obj.has_value()) {
ResponseFormat response_format;
response_format.type = json::LookupOrDefault<std::string>(response_format_obj.value(), "type",
response_format.type);
std::optional<std::string> schema =
json::LookupOptional<std::string>(response_format_obj.value(), "schema");
if (schema.has_value()) {
response_format.schema = schema.value();
Result<ResponseFormat> response_format_res =
ResponseFormat::FromJSON(response_format_obj.value());
if (response_format_res.IsErr()) {
return TResult::Error(response_format_res.UnwrapErr());
}
n->response_format = response_format;
n->response_format = response_format_res.Unwrap();
} else {
n->response_format = default_config->response_format;
}
cpp/serve/config.h (11 additions, 0 deletions)

@@ -28,6 +28,17 @@ using namespace tvm::runtime;
struct ResponseFormat {
String type = "text";
Optional<String> schema = NullOpt;
/*!
* \brief Create a response format config from JSON.
* \param config_json The JSON object for the response format config.
* \returns The converted result.
*/
static Result<ResponseFormat> FromJSON(const picojson::object& config_json);

/*!
* \return The serialized JSON value of the config.
*/
picojson::object AsJSON() const;
};

enum class SpecialRequestKind : int {
@@ -1,5 +1,3 @@
// NOTE: This example is still work in progress
//
// This is a minimum example App to interact with MLC Engine
// This app is mainly created with minimalism in mind for
// example and quick testing purposes.