281 changes: 145 additions & 136 deletions cpp/json_ffi/conv_template.cc

Large diffs are not rendered by default.

34 changes: 9 additions & 25 deletions cpp/json_ffi/conv_template.h
@@ -10,6 +10,7 @@
#include <vector>

#include "../serve/data.h"
#include "../support/result.h"
#include "picojson.h"

using namespace mlc::llm::serve;
@@ -86,34 +87,17 @@ struct Conversation {
// Function call fields
// whether using function calling or not, helps check for output message format in API call
std::optional<std::string> function_string = std::nullopt;
- std::optional<bool> use_function_calling = false;
+ bool use_function_calling = false;

Conversation();

- /**
-  * @brief Checks the size of the separators vector.
-  * This function checks if the size of the separators vector is either 1 or 2.
-  * If the size is not 1 or 2, it throws an invalid_argument exception.
-  */
- static std::vector<std::string> CheckMessageSeps(std::vector<std::string>& seps);
-
- /*!
-  * \brief Create the list of prompts from the messages based on the conversation template.
-  * When creation fails, errors are dumped to the input error string, and nullopt is returned.
-  */
- std::optional<std::vector<Data>> AsPrompt(std::string* err);
-
- /*!
-  * \brief Create a Conversation instance from the given JSON object.
-  * When creation fails, errors are dumped to the input error string, and nullopt is returned.
-  */
- static std::optional<Conversation> FromJSON(const picojson::object& json, std::string* err);
-
- /*!
-  * \brief Parse and create a Conversation instance from the given JSON string.
-  * When creation fails, errors are dumped to the input error string, and nullopt is returned.
-  */
- static std::optional<Conversation> FromJSON(const std::string& json_str, std::string* err);
+ /*! \brief Create the list of prompts from the messages based on the conversation template. */
+ Result<std::vector<Data>> AsPrompt();
+
+ /*! \brief Create a Conversation instance from the given JSON object. */
+ static Result<Conversation> FromJSON(const picojson::object& json);
+ /*! \brief Parse and create a Conversation instance from the given JSON string. */
+ static Result<Conversation> FromJSON(const std::string& json_str);
};

} // namespace json_ffi
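For context, a minimal usage sketch (not part of this PR) of the new header API: error handling moves from std::optional plus an error-string out-parameter to the Result<T> wrapper pulled in from ../support/result.h. BuildPromptOrError below is a hypothetical helper, assuming only the declarations above and the IsErr() / Unwrap() / UnwrapErr() accessors that the engine code further down already uses.

// Hypothetical caller of the new Result-based Conversation API (illustration only;
// in the engine, the request messages are assigned to conv_template.messages before
// AsPrompt is called). Assumes conv_template.h above (Conversation, Data, Result<T>).
std::string BuildPromptOrError(const picojson::object& conv_template_json) {
  Result<Conversation> conv_res = Conversation::FromJSON(conv_template_json);
  if (conv_res.IsErr()) {
    return "invalid conversation template: " + conv_res.UnwrapErr();
  }
  Conversation conv_template = conv_res.Unwrap();
  Result<std::vector<Data>> prompt_res = conv_template.AsPrompt();
  if (prompt_res.IsErr()) {
    return "failed to build prompt: " + prompt_res.UnwrapErr();
  }
  // prompt_res.Unwrap() holds the list of Data prompts to feed to the engine.
  return "";
}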
38 changes: 20 additions & 18 deletions cpp/json_ffi/json_ffi_engine.cc
@@ -41,16 +41,16 @@ void JSONFFIEngine::StreamBackError(std::string request_id) {
response.model = "json_ffi"; // TODO: Return model name from engine (or from args)
response.system_fingerprint = "";

- this->request_stream_callback_(Array<String>{picojson::value(response.ToJSON()).serialize()});
+ this->request_stream_callback_(Array<String>{picojson::value(response.AsJSON()).serialize()});
}

bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) {
- std::optional<ChatCompletionRequest> optional_request =
-     ChatCompletionRequest::FromJSON(request_json_str, &err_);
- if (!optional_request.has_value()) {
+ Result<ChatCompletionRequest> request_res = ChatCompletionRequest::FromJSON(request_json_str);
+ if (request_res.IsErr()) {
+   err_ = request_res.UnwrapErr();
return false;
}
- ChatCompletionRequest request = optional_request.value();
+ ChatCompletionRequest request = request_res.Unwrap();
// Create Request
// TODO: Check if request_id is present already

@@ -74,17 +74,20 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
conv_template.messages = messages;

// check function calling
- bool success_check = request.CheckFunctionCalling(conv_template, &err_);
- if (!success_check) {
+ Result<Conversation> updated_conv_template = request.CheckFunctionCalling(conv_template);
+ if (updated_conv_template.IsErr()) {
+   err_ = updated_conv_template.UnwrapErr();
return false;
}
+ conv_template = updated_conv_template.Unwrap();

// get prompt
- std::optional<Array<Data>> inputs_obj = conv_template.AsPrompt(&err_);
- if (!inputs_obj.has_value()) {
+ Result<std::vector<Data>> inputs_obj = conv_template.AsPrompt();
+ if (inputs_obj.IsErr()) {
+   err_ = inputs_obj.UnwrapErr();
return false;
}
- Array<Data> inputs = inputs_obj.value();
+ Array<Data> inputs = inputs_obj.Unwrap();

// generation_cfg
Array<String> stop_strs;
@@ -162,18 +165,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
this->engine_->Reload(engine_config_json_str);
this->default_generation_cfg_json_str_ = this->engine_->GetDefaultGenerationConfigJSONString();
picojson::object engine_config_json =
-     json::ParseToJsonObject(this->engine_->GetCompleteEngineConfigJSONString());
+     json::ParseToJSONObject(this->engine_->GetCompleteEngineConfigJSONString());

// Load conversation template.
Result<picojson::object> model_config_json =
serve::Model::LoadModelConfig(json::Lookup<std::string>(engine_config_json, "model"));
CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr();
- std::optional<Conversation> conv_template = Conversation::FromJSON(
-     json::Lookup<picojson::object>(model_config_json.Unwrap(), "conv_template"), &err_);
- if (!conv_template.has_value()) {
-   LOG(FATAL) << "Invalid conversation template JSON: " << err_;
- }
- this->conv_template_ = conv_template.value();
+ Result<Conversation> conv_template = Conversation::FromJSON(
+     json::Lookup<picojson::object>(model_config_json.Unwrap(), "conv_template"));
+ CHECK(!conv_template.IsErr()) << "Invalid conversation template JSON: "
+                               << conv_template.UnwrapErr();
+ this->conv_template_ = conv_template.Unwrap();
// Create streamer.
// Todo(mlc-team): Create one streamer for each request, instead of a global one.
this->streamer_ =
@@ -240,7 +242,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
response.choices = choices;
response.model = "json_ffi"; // TODO: Return model name from engine (or from args)
response.system_fingerprint = "";
- response_arr.push_back(picojson::value(response.ToJSON()).serialize());
+ response_arr.push_back(picojson::value(response.AsJSON()).serialize());
}
return response_arr;
}
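Across these call sites the pattern is the same: FromJSON, AsPrompt, and CheckFunctionCalling now return Result<T>, and the caller branches on IsErr(), stashes UnwrapErr() into err_ (or CHECKs it), and otherwise continues with Unwrap(). The actual Result<T> lives in cpp/support/result.h and is not shown in this diff; the stand-in below sketches only the surface used here, and its Ok/Err factory functions are assumptions for illustration.

// Simplified stand-in for Result<T> -- NOT the real cpp/support/result.h.
// It models only the members exercised in this diff (IsOk, IsErr, Unwrap,
// UnwrapErr); the Ok/Err constructors are hypothetical.
#include <string>
#include <utility>
#include <variant>

template <typename T>
class Result {
 public:
  static Result Ok(T value) { return Result(Value{std::move(value)}); }
  static Result Err(std::string msg) { return Result(Error{std::move(msg)}); }

  bool IsOk() const { return std::holds_alternative<Value>(data_); }
  bool IsErr() const { return std::holds_alternative<Error>(data_); }

  // Precondition: IsOk() is true.
  T Unwrap() const { return std::get<Value>(data_).value; }
  // Precondition: IsErr() is true.
  std::string UnwrapErr() const { return std::get<Error>(data_).message; }

 private:
  struct Value { T value; };
  struct Error { std::string message; };
  using Storage = std::variant<Value, Error>;
  explicit Result(Storage data) : data_(std::move(data)) {}
  Storage data_;
};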