diff --git a/cpp/json_ffi/conv_template.cc b/cpp/json_ffi/conv_template.cc index 9511bb5b64..4feee6f98e 100644 --- a/cpp/json_ffi/conv_template.cc +++ b/cpp/json_ffi/conv_template.cc @@ -34,14 +34,8 @@ Conversation::Conversation() {"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]}, {"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {} -std::vector Conversation::CheckMessageSeps(std::vector& seps) { - if (seps.size() == 0 || seps.size() > 2) { - throw std::invalid_argument("seps should have size 1 or 2."); - } - return seps; -} - -std::optional> Conversation::AsPrompt(std::string* err) { +Result> Conversation::AsPrompt() { + using TResult = Result>; // Get the system message std::string system_msg = system_template; size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]); @@ -64,11 +58,11 @@ std::optional> Conversation::AsPrompt(std::string* err) { for (int i = 0; i < messages.size(); i++) { std::string role = messages[i].role; + // Todo(mlc-team): support content to be a single string. std::optional>> content = messages[i].content; if (roles.find(role) == roles.end()) { - *err += "\nRole " + role + " is not supported. "; - return std::nullopt; + return TResult::Error("Role \"" + role + "\" is not supported"); } std::string separator = separators[role == "assistant"]; // check assistant role @@ -90,29 +84,30 @@ std::optional> Conversation::AsPrompt(std::string* err) { message += role_prefix; - for (auto& item : content.value()) { - if (item.find("type") == item.end()) { - *err += "Content item should have a type field"; - return std::nullopt; + for (const auto& item : content.value()) { + auto it_type = item.find("type"); + if (it_type == item.end()) { + return TResult::Error("The content of a message does not have \"type\" field"); } - if (item["type"] == "text") { - if (item.find("text") == item.end()) { - *err += "Content item should have a text field"; - return std::nullopt; + if (it_type->second == "text") { + auto it_text = item.find("text"); + if (it_text == item.end()) { + return TResult::Error("The text type content of a message does not have \"text\" field"); } // replace placeholder[ROLE] with input message from role std::string role_text = role_templates[role]; std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)]; size_t pos = role_text.find(placeholder); if (pos != std::string::npos) { - role_text.replace(pos, placeholder.length(), item["text"]); + role_text.replace(pos, placeholder.length(), it_text->second); } - if (use_function_calling.has_value() && use_function_calling.value()) { + if (use_function_calling) { // replace placeholder[FUNCTION] with function_string // this assumes function calling is used for a single request scenario only if (!function_string.has_value()) { - *err += "Function string is required for function calling"; - return std::nullopt; + return TResult::Error( + "The function string in conversation template is not defined for function " + "calling."); } pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]); if (pos != std::string::npos) { @@ -122,8 +117,7 @@ std::optional> Conversation::AsPrompt(std::string* err) { } message += role_text; } else { - *err += "Unsupported content type: " + item["type"]; - return std::nullopt; + return TResult::Error("Unsupported content type: " + it_type->second); } } @@ -131,186 +125,201 @@ std::optional> Conversation::AsPrompt(std::string* err) { message_list.push_back(TextData(message)); } - return message_list; + return TResult::Ok(message_list); } -std::optional Conversation::FromJSON(const picojson::object& json, std::string* err) { +Result Conversation::FromJSON(const picojson::object& json_obj) { + using TResult = Result; Conversation conv; - // name - std::string name; - if (json::ParseJSONField(json, "name", name, err, false)) { - conv.name = name; + Result> name_res = + json::LookupOptionalWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } + conv.name = name_res.Unwrap(); - std::string system_template; - if (!json::ParseJSONField(json, "system_template", system_template, err, true)) { - return std::nullopt; + Result system_template_res = + json::LookupWithResultReturn(json_obj, "system_template"); + if (system_template_res.IsErr()) { + return TResult::Error(system_template_res.UnwrapErr()); } - conv.system_template = system_template; + conv.system_template = system_template_res.Unwrap(); - std::string system_message; - if (!json::ParseJSONField(json, "system_message", system_message, err, true)) { - return std::nullopt; + Result system_message_res = + json::LookupWithResultReturn(json_obj, "system_message"); + if (system_message_res.IsErr()) { + return TResult::Error(system_message_res.UnwrapErr()); } - conv.system_message = system_message; + conv.system_message = system_message_res.Unwrap(); - picojson::array system_prefix_token_ids_arr; - if (json::ParseJSONField(json, "system_prefix_token_ids", system_prefix_token_ids_arr, err, - false)) { + Result> system_prefix_token_ids_arr_res = + json::LookupOptionalWithResultReturn(json_obj, "system_prefix_token_ids"); + if (system_prefix_token_ids_arr_res.IsErr()) { + return TResult::Error(system_prefix_token_ids_arr_res.UnwrapErr()); + } + std::optional system_prefix_token_ids_arr = + system_prefix_token_ids_arr_res.Unwrap(); + if (system_prefix_token_ids_arr.has_value()) { std::vector system_prefix_token_ids; - for (const auto& token_id : system_prefix_token_ids_arr) { + system_prefix_token_ids.reserve(system_prefix_token_ids_arr.value().size()); + for (const auto& token_id : system_prefix_token_ids_arr.value()) { if (!token_id.is()) { - *err += "system_prefix_token_ids should be an array of integers."; - return std::nullopt; + return TResult::Error("A system prefix token id is not integer."); } system_prefix_token_ids.push_back(token_id.get()); } - conv.system_prefix_token_ids = system_prefix_token_ids; + conv.system_prefix_token_ids = std::move(system_prefix_token_ids); } - bool add_role_after_system_message; - if (!json::ParseJSONField(json, "add_role_after_system_message", add_role_after_system_message, - err, true)) { - return std::nullopt; + Result add_role_after_system_message_res = + json::LookupWithResultReturn(json_obj, "add_role_after_system_message"); + if (add_role_after_system_message_res.IsErr()) { + return TResult::Error(add_role_after_system_message_res.UnwrapErr()); } - conv.add_role_after_system_message = add_role_after_system_message; + conv.add_role_after_system_message = add_role_after_system_message_res.Unwrap(); - picojson::object roles_object; - if (!json::ParseJSONField(json, "roles", roles_object, err, true)) { - return std::nullopt; + Result roles_object_res = + json::LookupWithResultReturn(json_obj, "roles"); + if (roles_object_res.IsErr()) { + return TResult::Error(roles_object_res.UnwrapErr()); } - std::unordered_map roles; - for (const auto& role : roles_object) { + for (const auto& role : roles_object_res.Unwrap()) { if (!role.second.is()) { - *err += "roles should be a map of string to string."; - return std::nullopt; + return TResult::Error("A role value in the conversation template is not a string."); } - roles[role.first] = role.second.get(); + conv.roles[role.first] = role.second.get(); } - conv.roles = roles; - - picojson::object role_templates_object; - if (json::ParseJSONField(json, "role_templates", role_templates_object, err, false)) { - for (const auto& role : role_templates_object) { - if (!role.second.is()) { - *err += "role_templates should be a map of string to string."; - return std::nullopt; + + Result> role_templates_object_res = + json::LookupOptionalWithResultReturn(json_obj, "role_templates"); + if (role_templates_object_res.IsErr()) { + return TResult::Error(role_templates_object_res.UnwrapErr()); + } + std::optional role_templates_object = role_templates_object_res.Unwrap(); + if (role_templates_object.has_value()) { + for (const auto& [role, msg] : role_templates_object.value()) { + if (!msg.is()) { + return TResult::Error("A value in \"role_templates\" is not a string."); } - conv.role_templates[role.first] = role.second.get(); + conv.role_templates[role] = msg.get(); } } - picojson::array messages_arr; - if (!json::ParseJSONField(json, "messages", messages_arr, err, true)) { - return std::nullopt; + Result messages_arr_res = + json::LookupWithResultReturn(json_obj, "messages"); + if (messages_arr_res.IsErr()) { + return TResult::Error(messages_arr_res.UnwrapErr()); } - std::vector messages; - for (const auto& message : messages_arr) { + for (const auto& message : messages_arr_res.Unwrap()) { if (!message.is()) { - *err += "messages should be an array of objects."; - return std::nullopt; + return TResult::Error("A message in the conversation template is not a JSON object."); } picojson::object message_obj = message.get(); - std::string role; - if (!json::ParseJSONField(message_obj, "role", role, err, true)) { - *err += "role field is required in messages."; - return std::nullopt; + Result role_res = json::LookupWithResultReturn(message_obj, "role"); + if (role_res.IsErr()) { + return TResult::Error(role_res.UnwrapErr()); + } + Result> content_arr_res = + json::LookupOptionalWithResultReturn(message_obj, "content"); + if (content_arr_res.IsErr()) { + return TResult::Error(content_arr_res.UnwrapErr()); } - picojson::array content_arr; + std::optional content_arr = content_arr_res.Unwrap(); std::vector> content; - if (json::ParseJSONField(message_obj, "content", content_arr, err, false)) { - for (const auto& item : content_arr) { + if (content_arr.has_value()) { + content.reserve(content_arr.value().size()); + for (const auto& item : content_arr.value()) { + // Todo(mlc-team): allow content item to be a single string. if (!item.is()) { - *err += "Content item is not an object"; - return std::nullopt; + return TResult::Error("The content of conversation template message is not an object"); } std::unordered_map item_map; - picojson::object item_obj = item.get(); - for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); - ++i) { - item_map[i->first] = i->second.to_str(); + for (const auto& [key, value] : item.get()) { + item_map[key] = value.to_str(); } - content.push_back(item_map); + content.push_back(std::move(item_map)); } } - messages.push_back({role, content}); + conv.messages.push_back({role_res.Unwrap(), content}); } - conv.messages = messages; - picojson::array seps_arr; - if (!json::ParseJSONField(json, "seps", seps_arr, err, true)) { - return std::nullopt; + Result seps_arr_res = + json::LookupWithResultReturn(json_obj, "seps"); + if (seps_arr_res.IsErr()) { + return TResult::Error(seps_arr_res.UnwrapErr()); } std::vector seps; - for (const auto& sep : seps_arr) { + for (const auto& sep : seps_arr_res.Unwrap()) { if (!sep.is()) { - *err += "seps should be an array of strings."; - return std::nullopt; + return TResult::Error("A separator (\"seps\") of the conversation template is not a string"); } - seps.push_back(sep.get()); + conv.seps.push_back(sep.get()); } - conv.seps = seps; - std::string role_content_sep; - if (!json::ParseJSONField(json, "role_content_sep", role_content_sep, err, true)) { - return std::nullopt; + Result role_content_sep_res = + json::LookupWithResultReturn(json_obj, "role_content_sep"); + if (role_content_sep_res.IsErr()) { + return TResult::Error(role_content_sep_res.UnwrapErr()); } - conv.role_content_sep = role_content_sep; + conv.role_content_sep = role_content_sep_res.Unwrap(); - std::string role_empty_sep; - if (!json::ParseJSONField(json, "role_empty_sep", role_empty_sep, err, true)) { - return std::nullopt; + Result role_empty_sep_res = + json::LookupWithResultReturn(json_obj, "role_empty_sep"); + if (role_empty_sep_res.IsErr()) { + return TResult::Error(role_empty_sep_res.UnwrapErr()); } - conv.role_empty_sep = role_empty_sep; + conv.role_empty_sep = role_empty_sep_res.Unwrap(); - picojson::array stop_str_arr; - if (!json::ParseJSONField(json, "stop_str", stop_str_arr, err, true)) { - return std::nullopt; + Result stop_str_arr_res = + json::LookupWithResultReturn(json_obj, "stop_str"); + if (stop_str_arr_res.IsErr()) { + return TResult::Error(stop_str_arr_res.UnwrapErr()); } - std::vector stop_str; - for (const auto& stop : stop_str_arr) { + for (const auto& stop : stop_str_arr_res.Unwrap()) { if (!stop.is()) { - *err += "stop_str should be an array of strings."; - return std::nullopt; + return TResult::Error( + "A stop string (\"stop_str\") of the conversation template is not a string."); } - stop_str.push_back(stop.get()); + conv.stop_str.push_back(stop.get()); } - conv.stop_str = stop_str; - picojson::array stop_token_ids_arr; - if (!json::ParseJSONField(json, "stop_token_ids", stop_token_ids_arr, err, true)) { - return std::nullopt; + Result stop_token_ids_arr_res = + json::LookupWithResultReturn(json_obj, "stop_token_ids"); + if (stop_token_ids_arr_res.IsErr()) { + return TResult::Error(stop_token_ids_arr_res.UnwrapErr()); } - std::vector stop_token_ids; - for (const auto& stop : stop_token_ids_arr) { + for (const auto& stop : stop_token_ids_arr_res.Unwrap()) { if (!stop.is()) { - *err += "stop_token_ids should be an array of integers."; - return std::nullopt; + return TResult::Error( + "A stop token id (\"stop_token_ids\") of the conversation template is not an integer."); } - stop_token_ids.push_back(stop.get()); + conv.stop_token_ids.push_back(stop.get()); } - conv.stop_token_ids = stop_token_ids; - std::string function_string; - if (!json::ParseJSONField(json, "function_string", function_string, err, false)) { - conv.function_string = function_string; + Result> function_string_res = + json::LookupOptionalWithResultReturn(json_obj, "function_string"); + if (function_string_res.IsErr()) { + return TResult::Error(function_string_res.UnwrapErr()); } + conv.function_string = function_string_res.Unwrap(); - bool use_function_calling; - if (json::ParseJSONField(json, "use_function_calling", use_function_calling, err, false)) { - conv.use_function_calling = use_function_calling; + Result use_function_calling_res = json::LookupOrDefaultWithResultReturn( + json_obj, "use_function_calling", conv.use_function_calling); + if (use_function_calling_res.IsErr()) { + return TResult::Error(use_function_calling_res.UnwrapErr()); } + conv.use_function_calling = use_function_calling_res.Unwrap(); - return conv; + return TResult::Ok(conv); } -std::optional Conversation::FromJSON(const std::string& json_str, std::string* err) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!json_obj.has_value()) { - return std::nullopt; +Result Conversation::FromJSON(const std::string& json_str) { + Result json_obj = json::ParseToJSONObjectWithResultReturn(json_str); + if (json_obj.IsErr()) { + return Result::Error(json_obj.UnwrapErr()); } - return Conversation::FromJSON(json_obj.value(), err); + return Conversation::FromJSON(json_obj.Unwrap()); } } // namespace json_ffi diff --git a/cpp/json_ffi/conv_template.h b/cpp/json_ffi/conv_template.h index eeb348831c..2d579a8d94 100644 --- a/cpp/json_ffi/conv_template.h +++ b/cpp/json_ffi/conv_template.h @@ -10,6 +10,7 @@ #include #include "../serve/data.h" +#include "../support/result.h" #include "picojson.h" using namespace mlc::llm::serve; @@ -86,34 +87,17 @@ struct Conversation { // Function call fields // whether using function calling or not, helps check for output message format in API call std::optional function_string = std::nullopt; - std::optional use_function_calling = false; + bool use_function_calling = false; Conversation(); - /** - * @brief Checks the size of the separators vector. - * This function checks if the size of the separators vector is either 1 or 2. - * If the size is not 1 or 2, it throws an invalid_argument exception. - */ - static std::vector CheckMessageSeps(std::vector& seps); - - /*! - * \brief Create the list of prompts from the messages based on the conversation template. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - std::optional> AsPrompt(std::string* err); - - /*! - * \brief Create a Conversation instance from the given JSON object. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const picojson::object& json, std::string* err); - - /*! - * \brief Parse and create a Conversation instance from the given JSON string. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const std::string& json_str, std::string* err); + /*! \brief Create the list of prompts from the messages based on the conversation template. */ + Result> AsPrompt(); + + /*! \brief Create a Conversation instance from the given JSON object. */ + static Result FromJSON(const picojson::object& json); + /*! \brief Parse and create a Conversation instance from the given JSON string. */ + static Result FromJSON(const std::string& json_str); }; } // namespace json_ffi diff --git a/cpp/json_ffi/json_ffi_engine.cc b/cpp/json_ffi/json_ffi_engine.cc index 6b2676ee3f..b4f9751719 100644 --- a/cpp/json_ffi/json_ffi_engine.cc +++ b/cpp/json_ffi/json_ffi_engine.cc @@ -41,16 +41,16 @@ void JSONFFIEngine::StreamBackError(std::string request_id) { response.model = "json_ffi"; // TODO: Return model name from engine (or from args) response.system_fingerprint = ""; - this->request_stream_callback_(Array{picojson::value(response.ToJSON()).serialize()}); + this->request_stream_callback_(Array{picojson::value(response.AsJSON()).serialize()}); } bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) { - std::optional optional_request = - ChatCompletionRequest::FromJSON(request_json_str, &err_); - if (!optional_request.has_value()) { + Result request_res = ChatCompletionRequest::FromJSON(request_json_str); + if (request_res.IsErr()) { + err_ = request_res.UnwrapErr(); return false; } - ChatCompletionRequest request = optional_request.value(); + ChatCompletionRequest request = request_res.Unwrap(); // Create Request // TODO: Check if request_id is present already @@ -74,17 +74,20 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request conv_template.messages = messages; // check function calling - bool success_check = request.CheckFunctionCalling(conv_template, &err_); - if (!success_check) { + Result updated_conv_template = request.CheckFunctionCalling(conv_template); + if (updated_conv_template.IsErr()) { + err_ = updated_conv_template.UnwrapErr(); return false; } + conv_template = updated_conv_template.Unwrap(); // get prompt - std::optional> inputs_obj = conv_template.AsPrompt(&err_); - if (!inputs_obj.has_value()) { + Result> inputs_obj = conv_template.AsPrompt(); + if (inputs_obj.IsErr()) { + err_ = inputs_obj.UnwrapErr(); return false; } - Array inputs = inputs_obj.value(); + Array inputs = inputs_obj.Unwrap(); // generation_cfg Array stop_strs; @@ -162,18 +165,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { this->engine_->Reload(engine_config_json_str); this->default_generation_cfg_json_str_ = this->engine_->GetDefaultGenerationConfigJSONString(); picojson::object engine_config_json = - json::ParseToJsonObject(this->engine_->GetCompleteEngineConfigJSONString()); + json::ParseToJSONObject(this->engine_->GetCompleteEngineConfigJSONString()); // Load conversation template. Result model_config_json = serve::Model::LoadModelConfig(json::Lookup(engine_config_json, "model")); CHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr(); - std::optional conv_template = Conversation::FromJSON( - json::Lookup(model_config_json.Unwrap(), "conv_template"), &err_); - if (!conv_template.has_value()) { - LOG(FATAL) << "Invalid conversation template JSON: " << err_; - } - this->conv_template_ = conv_template.value(); + Result conv_template = Conversation::FromJSON( + json::Lookup(model_config_json.Unwrap(), "conv_template")); + CHECK(!conv_template.IsErr()) << "Invalid conversation template JSON: " + << conv_template.UnwrapErr(); + this->conv_template_ = conv_template.Unwrap(); // Create streamer. // Todo(mlc-team): Create one streamer for each request, instead of a global one. this->streamer_ = @@ -240,7 +242,7 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode { response.choices = choices; response.model = "json_ffi"; // TODO: Return model name from engine (or from args) response.system_fingerprint = ""; - response_arr.push_back(picojson::value(response.ToJSON()).serialize()); + response_arr.push_back(picojson::value(response.AsJSON()).serialize()); } return response_arr; } diff --git a/cpp/json_ffi/openai_api_protocol.cc b/cpp/json_ffi/openai_api_protocol.cc index 4547108eb5..c07de8fef5 100644 --- a/cpp/json_ffi/openai_api_protocol.cc +++ b/cpp/json_ffi/openai_api_protocol.cc @@ -11,53 +11,41 @@ namespace mlc { namespace llm { namespace json_ffi { -std::string generate_uuid_string(size_t length) { - auto randchar = []() -> char { - const char charset[] = - "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz"; - const size_t max_index = (sizeof(charset) - 1); - return charset[rand() % max_index]; - }; - std::string str(length, 0); - std::generate_n(str.begin(), length, randchar); - return str; -} - -std::optional ChatFunction::FromJSON(const picojson::object& json_obj, - std::string* err) { - ChatFunction chatFunc; +Result ChatFunction::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatFunction chat_func; - // description (optional) - std::string description; - if (json::ParseJSONField(json_obj, "description", description, err, false)) { - chatFunc.description = description; + // description + Result> description_res = + json::LookupOptionalWithResultReturn(json_obj, "description"); + if (description_res.IsErr()) { + return TResult::Error(description_res.UnwrapErr()); } + chat_func.description = description_res.Unwrap(); // name - std::string name; - if (!json::ParseJSONField(json_obj, "name", name, err, true)) { - return std::nullopt; + Result name_res = json::LookupWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } - chatFunc.name = name; + chat_func.name = name_res.Unwrap(); // parameters - picojson::object parameters_obj; - if (!json::ParseJSONField(json_obj, "parameters", parameters_obj, err, true)) { - return std::nullopt; + Result parameters_obj_res = + json::LookupWithResultReturn(json_obj, "parameters"); + if (parameters_obj_res.IsErr()) { + return TResult::Error(parameters_obj_res.UnwrapErr()); } - std::unordered_map parameters; - for (picojson::value::object::const_iterator i = parameters_obj.begin(); - i != parameters_obj.end(); ++i) { - parameters[i->first] = i->second.to_str(); + picojson::object parameters_obj = parameters_obj_res.Unwrap(); + chat_func.parameters.reserve(parameters_obj.size()); + for (const auto& [key, value] : parameters_obj) { + chat_func.parameters[key] = value.to_str(); } - chatFunc.parameters = parameters; - return chatFunc; + return TResult::Ok(chat_func); } -picojson::object ChatFunction::ToJSON() const { +picojson::object ChatFunction::AsJSON() const { picojson::object obj; if (this->description.has_value()) { obj["description"] = picojson::value(this->description.value()); @@ -71,57 +59,63 @@ picojson::object ChatFunction::ToJSON() const { return obj; } -std::optional ChatTool::FromJSON(const picojson::object& json_obj, std::string* err) { +Result ChatTool::FromJSON(const picojson::object& json_obj) { + using TResult = Result; ChatTool chatTool; // function - picojson::object function_obj; - if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { - return std::nullopt; + Result function_obj_res = + json::LookupWithResultReturn(json_obj, "function"); + if (function_obj_res.IsErr()) { + return TResult::Error(function_obj_res.UnwrapErr()); } - - std::optional function = ChatFunction::FromJSON(function_obj, err); - if (!function.has_value()) { - return std::nullopt; + Result function = ChatFunction::FromJSON(function_obj_res.Unwrap()); + if (function.IsErr()) { + return TResult::Error(function.UnwrapErr()); } - chatTool.function = function.value(); + chatTool.function = function.Unwrap(); - return chatTool; + return TResult::Ok(chatTool); } -picojson::object ChatTool::ToJSON() const { +picojson::object ChatTool::AsJSON() const { picojson::object obj; obj["type"] = picojson::value("function"); - obj["function"] = picojson::value(this->function.ToJSON()); + obj["function"] = picojson::value(this->function.AsJSON()); return obj; } -std::optional ChatFunctionCall::FromJSON(const picojson::object& json_obj, - std::string* err) { - ChatFunctionCall chatFuncCall; +Result ChatFunctionCall::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatFunctionCall chat_func_call; // name - std::string name; - if (!json::ParseJSONField(json_obj, "name", name, err, true)) { - return std::nullopt; + Result name_res = json::LookupWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } - chatFuncCall.name = name; + chat_func_call.name = name_res.Unwrap(); // arguments - picojson::object arguments_obj; - if (json::ParseJSONField(json_obj, "arguments", arguments_obj, err, false)) { + Result> arguments_obj_res = + json::LookupOptionalWithResultReturn(json_obj, "arguments"); + if (arguments_obj_res.IsErr()) { + return TResult::Error(arguments_obj_res.UnwrapErr()); + } + std::optional arguments_obj = arguments_obj_res.Unwrap(); + if (arguments_obj.has_value()) { std::unordered_map arguments; - for (picojson::value::object::const_iterator i = arguments_obj.begin(); - i != arguments_obj.end(); ++i) { - arguments[i->first] = i->second.to_str(); + arguments.reserve(arguments_obj.value().size()); + for (const auto& [key, value] : arguments_obj.value()) { + arguments[key] = value.to_str(); } - chatFuncCall.arguments = arguments; + chat_func_call.arguments = std::move(arguments); } - return chatFuncCall; + return TResult::Ok(chat_func_call); } -picojson::object ChatFunctionCall::ToJSON() const { +picojson::object ChatFunctionCall::AsJSON() const { picojson::object obj; picojson::object arguments_obj; if (this->arguments.has_value()) { @@ -135,69 +129,75 @@ picojson::object ChatFunctionCall::ToJSON() const { return obj; } -std::optional ChatToolCall::FromJSON(const picojson::object& json_obj, - std::string* err) { - ChatToolCall chatToolCall; +Result ChatToolCall::FromJSON(const picojson::object& json_obj) { + using TResult = Result; + ChatToolCall chat_tool_call; // function - picojson::object function_obj; - if (!json::ParseJSONField(json_obj, "function", function_obj, err, true)) { - return std::nullopt; + Result function_obj_res = + json::LookupWithResultReturn(json_obj, "function"); + if (function_obj_res.IsErr()) { + return TResult::Error(function_obj_res.UnwrapErr()); } - - std::optional function = ChatFunctionCall::FromJSON(function_obj, err); - if (!function.has_value()) { - return std::nullopt; - }; - chatToolCall.function = function.value(); + Result function_res = ChatFunctionCall::FromJSON(function_obj_res.Unwrap()); + if (function_res.IsErr()) { + return TResult::Error(function_res.UnwrapErr()); + } + chat_tool_call.function = function_res.Unwrap(); // overwrite default id - std::string id; - if (!json::ParseJSONField(json_obj, "id", id, err, false)) { - return std::nullopt; + Result> id_res = + json::LookupOptionalWithResultReturn(json_obj, "id"); + if (id_res.IsErr()) { + return TResult::Error(id_res.UnwrapErr()); + } + std::optional id = id_res.UnwrapErr(); + if (id.has_value()) { + chat_tool_call.id = id.value(); } - chatToolCall.id = id; - return chatToolCall; + return TResult::Ok(chat_tool_call); } -picojson::object ChatToolCall::ToJSON() const { +picojson::object ChatToolCall::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); - obj["function"] = picojson::value(this->function.ToJSON()); + obj["function"] = picojson::value(this->function.AsJSON()); obj["type"] = picojson::value("function"); return obj; } -std::optional ChatCompletionMessage::FromJSON( - const picojson::object& json_obj, std::string* err) { +Result ChatCompletionMessage::FromJSON(const picojson::object& json_obj) { + using TResult = Result; ChatCompletionMessage message; // content - picojson::array content_arr; - if (!json::ParseJSONField(json_obj, "content", content_arr, err, true)) { - return std::nullopt; - } - std::vector > content; - for (const auto& item : content_arr) { + Result content_arr_res = + json::LookupWithResultReturn(json_obj, "content"); + if (content_arr_res.IsErr()) { + return TResult::Error(content_arr_res.UnwrapErr()); + } + std::vector> content; + for (const auto& item : content_arr_res.Unwrap()) { + // Todo(mlc-team): allow content item to be a single string. if (!item.is()) { - *err += "Content item is not an object"; - return std::nullopt; + return TResult::Error("The content of chat completion message is not an object"); } - std::unordered_map item_map; picojson::object item_obj = item.get(); - for (picojson::value::object::const_iterator i = item_obj.begin(); i != item_obj.end(); ++i) { - item_map[i->first] = i->second.to_str(); + std::unordered_map item_map; + for (const auto& [key, value] : item_obj) { + item_map[key] = value.to_str(); } - content.push_back(item_map); + content.push_back(std::move(item_map)); } message.content = content; // role - std::string role_str; - if (!json::ParseJSONField(json_obj, "role", role_str, err, true)) { - return std::nullopt; + Result role_str_res = json::LookupWithResultReturn(json_obj, "role"); + if (role_str_res.IsErr()) { + return TResult::Error(role_str_res.UnwrapErr()); } + std::string role_str = role_str_res.Unwrap(); if (role_str == "system") { message.role = Role::system; } else if (role_str == "user") { @@ -207,124 +207,148 @@ std::optional ChatCompletionMessage::FromJSON( } else if (role_str == "tool") { message.role = Role::tool; } else { - *err += "Invalid role"; - return std::nullopt; + return TResult::Error("Invalid role in chat completion message: " + role_str); } // name - std::string name; - if (json::ParseJSONField(json_obj, "name", name, err, false)) { - message.name = name; + Result> name_res = + json::LookupOptionalWithResultReturn(json_obj, "name"); + if (name_res.IsErr()) { + return TResult::Error(name_res.UnwrapErr()); } + message.name = name_res.Unwrap(); // tool calls - picojson::array tool_calls_arr; - if (json::ParseJSONField(json_obj, "tool_calls", tool_calls_arr, err, false)) { + Result> tool_calls_arr_res = + json::LookupOptionalWithResultReturn(json_obj, "tool_calls"); + if (tool_calls_arr_res.IsErr()) { + return TResult::Error(tool_calls_arr_res.UnwrapErr()); + } + std::optional tool_calls_arr = tool_calls_arr_res.Unwrap(); + if (tool_calls_arr.has_value()) { std::vector tool_calls; - for (const auto& item : tool_calls_arr) { + tool_calls.reserve(tool_calls_arr.value().size()); + for (const auto& item : tool_calls_arr.value()) { if (!item.is()) { - *err += "Chat Tool Call item is not an object"; - return std::nullopt; + return TResult::Error("A tool call item in the chat completion message is not an object"); + } + Result tool_call = ChatToolCall::FromJSON(item.get()); + if (tool_call.IsErr()) { + return TResult::Error(tool_call.UnwrapErr()); } - picojson::object item_obj = item.get(); - std::optional tool_call = ChatToolCall::FromJSON(item_obj, err); - if (!tool_call.has_value()) { - return std::nullopt; - }; - tool_calls.push_back(tool_call.value()); + tool_calls.push_back(tool_call.Unwrap()); } message.tool_calls = tool_calls; } // tool call id - std::string tool_call_id; - if (json::ParseJSONField(json_obj, "tool_call_id", tool_call_id, err, false)) { - message.tool_call_id = tool_call_id; + Result> tool_call_id_res = + json::LookupOptionalWithResultReturn(json_obj, "tool_call_id"); + if (tool_call_id_res.IsErr()) { + return TResult::Error(tool_call_id_res.UnwrapErr()); } + message.tool_call_id = tool_call_id_res.Unwrap(); - return message; + return TResult::Ok(message); } -std::optional ChatCompletionRequest::FromJSON( - const picojson::object& json_obj, std::string* err) { +Result ChatCompletionRequest::FromJSON(const std::string& json_str) { + using TResult = Result; + Result json_obj_res = json::ParseToJSONObjectWithResultReturn(json_str); + if (json_obj_res.IsErr()) { + return TResult::Error(json_obj_res.UnwrapErr()); + } + picojson::object json_obj = json_obj_res.Unwrap(); ChatCompletionRequest request; // messages - picojson::array messages_arr; - if (!json::ParseJSONField(json_obj, "messages", messages_arr, err, true)) { - return std::nullopt; + Result messages_arr_res = + json::LookupWithResultReturn(json_obj, "messages"); + if (messages_arr_res.IsErr()) { + return TResult::Error(messages_arr_res.UnwrapErr()); } std::vector messages; - for (const auto& item : messages_arr) { + for (const auto& item : messages_arr_res.Unwrap()) { + if (!item.is()) { + return TResult::Error("A message in chat completion request is not object"); + } picojson::object item_obj = item.get(); - std::optional message = ChatCompletionMessage::FromJSON(item_obj, err); - if (!message.has_value()) { - return std::nullopt; + Result message = ChatCompletionMessage::FromJSON(item_obj); + if (message.IsErr()) { + return TResult::Error(message.UnwrapErr()); } - messages.push_back(message.value()); + messages.push_back(message.Unwrap()); } request.messages = messages; // model - std::string model; - if (!json::ParseJSONField(json_obj, "model", model, err, true)) { - return std::nullopt; + Result model_res = json::LookupWithResultReturn(json_obj, "model"); + if (model_res.IsErr()) { + return TResult::Error(model_res.UnwrapErr()); + } + request.model = model_res.Unwrap(); + + // max_tokens + Result> max_tokens_res = + json::LookupOptionalWithResultReturn(json_obj, "max_tokens"); + if (max_tokens_res.IsErr()) { + return TResult::Error(max_tokens_res.UnwrapErr()); } - request.model = model; + request.max_tokens = max_tokens_res.Unwrap(); // frequency_penalty - double frequency_penalty; - if (json::ParseJSONField(json_obj, "frequency_penalty", frequency_penalty, err, false)) { - request.frequency_penalty = frequency_penalty; + Result> frequency_penalty_res = + json::LookupOptionalWithResultReturn(json_obj, "frequency_penalty"); + if (frequency_penalty_res.IsErr()) { + return TResult::Error(frequency_penalty_res.UnwrapErr()); } + request.frequency_penalty = frequency_penalty_res.Unwrap(); // presence_penalty - double presence_penalty; - if (json::ParseJSONField(json_obj, "presence_penalty", presence_penalty, err, false)) { - request.presence_penalty = presence_penalty; + Result> presence_penalty_res = + json::LookupOptionalWithResultReturn(json_obj, "presence_penalty"); + if (presence_penalty_res.IsErr()) { + return TResult::Error(presence_penalty_res.UnwrapErr()); } + request.presence_penalty = presence_penalty_res.Unwrap(); // tool_choice - std::string tool_choice = "auto"; - request.tool_choice = tool_choice; - if (json::ParseJSONField(json_obj, "tool_choice", tool_choice, err, false)) { - request.tool_choice = tool_choice; + Result tool_choice_res = + json::LookupOrDefaultWithResultReturn(json_obj, "tool_choice", "auto"); + if (tool_choice_res.IsErr()) { + return TResult::Error(tool_choice_res.UnwrapErr()); } + request.tool_choice = tool_choice_res.Unwrap(); // tools - picojson::array tools_arr; - if (json::ParseJSONField(json_obj, "tools", tools_arr, err, false)) { + Result> tools_arr_res = + json::LookupOptionalWithResultReturn(json_obj, "tools"); + if (tool_choice_res.IsErr()) { + return TResult::Error(tool_choice_res.UnwrapErr()); + } + std::optional tools_arr = tools_arr_res.Unwrap(); + if (tools_arr.has_value()) { std::vector tools; - for (const auto& item : tools_arr) { + tools.reserve(tools_arr.value().size()); + for (const auto& item : tools_arr.value()) { if (!item.is()) { - *err += "Chat Tool item is not an object"; - return std::nullopt; + return TResult::Error("A tool of the chat completion request is not an object"); + } + Result tool = ChatTool::FromJSON(item.get()); + if (tool.IsErr()) { + return TResult::Error(tool.UnwrapErr()); } - picojson::object item_obj = item.get(); - std::optional tool = ChatTool::FromJSON(item_obj, err); - if (!tool.has_value()) { - return std::nullopt; - }; - tools.push_back(tool.value()); + tools.push_back(tool.Unwrap()); } request.tools = tools; } // TODO: Other parameters - return request; -} - -std::optional ChatCompletionRequest::FromJSON(const std::string& json_str, - std::string* err) { - std::optional json_obj = json::LoadJSONFromString(json_str, err); - if (!json_obj.has_value()) { - return std::nullopt; - } - return ChatCompletionRequest::FromJSON(json_obj.value(), err); + return TResult::Ok(request); } -picojson::object ChatCompletionMessage::ToJSON() const { +picojson::object ChatCompletionMessage::AsJSON() const { picojson::object obj; picojson::array content_arr; for (const auto& item : this->content.value()) { @@ -353,17 +377,18 @@ picojson::object ChatCompletionMessage::ToJSON() const { if (this->tool_calls.has_value()) { picojson::array tool_calls_arr; for (const auto& tool_call : this->tool_calls.value()) { - tool_calls_arr.push_back(picojson::value(tool_call.ToJSON())); + tool_calls_arr.push_back(picojson::value(tool_call.AsJSON())); } obj["tool_calls"] = picojson::value(tool_calls_arr); } return obj; } -bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, std::string* err) { +Result ChatCompletionRequest::CheckFunctionCalling(Conversation conv_template) { + using TResult = Result; if (!tools.has_value() || (tool_choice.has_value() && tool_choice.value() == "none")) { conv_template.use_function_calling = false; - return true; + return TResult::Ok(conv_template); } std::vector tools_ = tools.value(); std::string tool_choice_ = tool_choice.value(); @@ -372,29 +397,28 @@ bool ChatCompletionRequest::CheckFunctionCalling(Conversation& conv_template, st for (const auto& tool : tools_) { if (tool.function.name == tool_choice_) { conv_template.use_function_calling = true; - picojson::value function_str(tool.function.ToJSON()); + picojson::value function_str(tool.function.AsJSON()); conv_template.function_string = function_str.serialize(); - return true; + return TResult::Ok(conv_template); } } if (tool_choice_ != "auto") { - *err += "Invalid tool_choice value: " + tool_choice_; - return false; + return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_); } picojson::array function_list; for (const auto& tool : tools_) { - function_list.push_back(picojson::value(tool.function.ToJSON())); + function_list.push_back(picojson::value(tool.function.AsJSON())); } conv_template.use_function_calling = true; picojson::value function_list_json(function_list); conv_template.function_string = function_list_json.serialize(); - return true; + return TResult::Ok(conv_template); }; -picojson::object ChatCompletionResponseChoice::ToJSON() const { +picojson::object ChatCompletionResponseChoice::AsJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -410,11 +434,11 @@ picojson::object ChatCompletionResponseChoice::ToJSON() const { } } obj["index"] = picojson::value((int64_t)this->index); - obj["message"] = picojson::value(this->message.ToJSON()); + obj["message"] = picojson::value(this->message.AsJSON()); return obj; } -picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { +picojson::object ChatCompletionStreamResponseChoice::AsJSON() const { picojson::object obj; if (!this->finish_reason.has_value()) { obj["finish_reason"] = picojson::value(); @@ -431,16 +455,16 @@ picojson::object ChatCompletionStreamResponseChoice::ToJSON() const { } obj["index"] = picojson::value((int64_t)this->index); - obj["delta"] = picojson::value(this->delta.ToJSON()); + obj["delta"] = picojson::value(this->delta.AsJSON()); return obj; } -picojson::object ChatCompletionResponse::ToJSON() const { +picojson::object ChatCompletionResponse::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; for (const auto& choice : this->choices) { - choices_arr.push_back(picojson::value(choice.ToJSON())); + choices_arr.push_back(picojson::value(choice.AsJSON())); } obj["choices"] = picojson::value(choices_arr); obj["created"] = picojson::value((int64_t)this->created); @@ -450,12 +474,12 @@ picojson::object ChatCompletionResponse::ToJSON() const { return obj; } -picojson::object ChatCompletionStreamResponse::ToJSON() const { +picojson::object ChatCompletionStreamResponse::AsJSON() const { picojson::object obj; obj["id"] = picojson::value(this->id); picojson::array choices_arr; for (const auto& choice : this->choices) { - choices_arr.push_back(picojson::value(choice.ToJSON())); + choices_arr.push_back(picojson::value(choice.AsJSON())); } obj["choices"] = picojson::value(choices_arr); obj["created"] = picojson::value((int64_t)this->created); diff --git a/cpp/json_ffi/openai_api_protocol.h b/cpp/json_ffi/openai_api_protocol.h index 70ef2fb22f..914366c2f1 100644 --- a/cpp/json_ffi/openai_api_protocol.h +++ b/cpp/json_ffi/openai_api_protocol.h @@ -13,6 +13,7 @@ #include #include +#include "../support/result.h" #include "conv_template.h" #include "picojson.h" @@ -24,17 +25,30 @@ enum class Role { system, user, assistant, tool }; enum class Type { text, json_object, function }; enum class FinishReason { stop, length, tool_calls, error }; -std::string generate_uuid_string(size_t length); +inline std::string generate_uuid_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} class ChatFunction { public: std::optional description = std::nullopt; std::string name; + // Todo: change to std::vector>? std::unordered_map parameters; // Assuming parameters are string key-value pairs - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatTool { @@ -42,8 +56,8 @@ class ChatTool { Type type = Type::function; ChatFunction function; - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatFunctionCall { @@ -52,8 +66,8 @@ class ChatFunctionCall { std::optional> arguments = std::nullopt; // Assuming arguments are string key-value pairs - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatToolCall { @@ -62,8 +76,8 @@ class ChatToolCall { Type type = Type::function; ChatFunctionCall function; - static std::optional FromJSON(const picojson::object& json, std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class ChatCompletionMessage { @@ -75,9 +89,8 @@ class ChatCompletionMessage { std::optional> tool_calls = std::nullopt; std::optional tool_call_id = std::nullopt; - static std::optional FromJSON(const picojson::object& json, - std::string* err); - picojson::object ToJSON() const; + static Result FromJSON(const picojson::object& json); + picojson::object AsJSON() const; }; class RequestResponseFormat { @@ -108,20 +121,10 @@ class ChatCompletionRequest { bool ignore_eos = false; // RequestResponseFormat response_format; //TODO: implement this - /*! - * \brief Create a ChatCompletionRequest instance from the given JSON object. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const picojson::object& json_obj, - std::string* err); - /*! - * \brief Parse and create a ChatCompletionRequest instance from the given JSON string. - * When creation fails, errors are dumped to the input error string, and nullopt is returned. - */ - static std::optional FromJSON(const std::string& json_str, - std::string* err); - - bool CheckFunctionCalling(Conversation& conv_template, std::string* err); + /*! \brief Parse and create a ChatCompletionRequest instance from the given JSON string. */ + static Result FromJSON(const std::string& json_str); + + Result CheckFunctionCalling(Conversation conv_template); // TODO: check_penalty_range, check_logit_bias, check_logprobs }; @@ -132,7 +135,7 @@ class ChatCompletionResponseChoice { ChatCompletionMessage message; // TODO: logprobs - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; class ChatCompletionStreamResponseChoice { @@ -142,7 +145,7 @@ class ChatCompletionStreamResponseChoice { ChatCompletionMessage delta; // TODO: logprobs - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; class ChatCompletionResponse { @@ -155,7 +158,7 @@ class ChatCompletionResponse { std::string object = "chat.completion"; // TODO: usage_info - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; class ChatCompletionStreamResponse { @@ -167,7 +170,7 @@ class ChatCompletionStreamResponse { std::string system_fingerprint; std::string object = "chat.completion.chunk"; - picojson::object ToJSON() const; + picojson::object AsJSON() const; }; } // namespace json_ffi diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 2daf1d0338..62ba2787b9 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -90,7 +90,7 @@ ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module, std::string json_str = ""; TypedPackedFunc pf = module.GetFunction("_metadata"); json_str = pf(); - picojson::object json = json::ParseToJsonObject(json_str); + picojson::object json = json::ParseToJSONObject(json_str); try { return ModelMetadata::FromJSON(json, model_config); } catch (const std::exception& e) { diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 30a3617a8d..9b9d5ba65a 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -79,7 +79,7 @@ GenerationConfig::GenerationConfig( GenerationConfig::GenerationConfig(String config_json_str, Optional default_config_json_str) { - picojson::object config = json::ParseToJsonObject(config_json_str); + picojson::object config = json::ParseToJSONObject(config_json_str); ObjectPtr n = make_object(); GenerationConfig default_config; if (default_config_json_str.defined()) { diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index a0ae4d98f3..a4eda4e395 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -420,7 +420,7 @@ BNFGrammar EBNFParser::Parse(std::string ebnf_string, std::string main_rule) { BNFGrammar BNFJSONParser::Parse(std::string json_string) { auto node = make_object(); - auto grammar_json = json::ParseToJsonObject(json_string); + auto grammar_json = json::ParseToJSONObject(json_string); auto rules_json = json::Lookup(grammar_json, "rules"); for (const auto& rule_json : rules_json) { auto rule_json_obj = rule_json.get(); diff --git a/cpp/support/json_parser.h b/cpp/support/json_parser.h index f71757435a..ef1225081d 100644 --- a/cpp/support/json_parser.h +++ b/cpp/support/json_parser.h @@ -12,6 +12,8 @@ #include +#include "result.h" + namespace mlc { namespace llm { namespace json { @@ -21,52 +23,31 @@ namespace json { * \param json_str The JSON string to parse. * \return The parsed JSON object. */ -picojson::object ParseToJsonObject(const std::string& json_str); - -// Todo(mlc-team): implement "Result" class for JSON parsing with error collection. -/*! - * \brief Parse input JSON string into JSON dict. - * Any error will be dumped to the input error string. - */ -inline std::optional LoadJSONFromString(const std::string& json_str, - std::string* err) { - ICHECK_NOTNULL(err); - picojson::value json; - *err = picojson::parse(json, json_str); - if (!json.is()) { - *err += "The input JSON string does not correspond to a JSON dict."; - return std::nullopt; - } - return json.get(); +inline picojson::object ParseToJSONObject(const std::string& json_str) { + picojson::value result; + std::string err = picojson::parse(result, json_str); + CHECK(err.empty()) << "Failed to parse JSON: err. The JSON string is:" << json_str; + CHECK(result.is()) + << "ValueError: The given string is not a JSON object: " << json_str; + return result.get(); } - /*! - * \brief // Todo(mlc-team): document this function. - * \tparam T - * \param json_obj - * \param field - * \param value - * \param err - * \param required - * \return + * \brief Parse a JSON string to a JSON object. + * \param json_str The JSON string to parse. + * \return The parsed JSON object, or the error message. */ -template -inline bool ParseJSONField(const picojson::object& json_obj, const std::string& field, T& value, - std::string* err, bool required) { - // T can be int, double, bool, string, picojson::array - if (json_obj.count(field)) { - if (!json_obj.at(field).is()) { - *err += "Field " + field + " is not of type " + typeid(T).name() + "\n"; - return false; - } - value = json_obj.at(field).get(); - } else { - if (required) { - *err += "Field " + field + " is required\n"; - return false; - } +inline Result ParseToJSONObjectWithResultReturn(const std::string& json_str) { + using TResult = Result; + picojson::value result; + std::string err = picojson::parse(result, json_str); + if (!err.empty()) { + return TResult::Error("Failed to parse JSON: err. The JSON string is: " + json_str + + ". The error is " + err); + } + if (!result.is()) { + return TResult::Error("ValueError: The given string is not a JSON object: " + json_str); } - return true; + return TResult::Ok(result.get()); } /*! @@ -87,6 +68,109 @@ ValueType Lookup(const picojson::object& json, const std::string& key); */ template ValueType Lookup(const picojson::array& json, int index); +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, the default value is returned. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or the default value if the key doesn't exist or has null value. + */ +template +inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, + const ValueType& default_value) { + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return default_value; + } + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, return std::nullopt. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or std::nullopt if the value doesn't exist or has null value. + */ +template +inline std::optional LookupOptional(const picojson::object& json, + const std::string& key) { + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return std::nullopt; + } + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or the error message. + */ +template +inline Result LookupWithResultReturn(const picojson::object& json, + const std::string& key) { + using TResult = Result; + auto it = json.find(key); + if (it == json.end()) { + return TResult::Error("ValueError: key \"" + key + "\" not found in the JSON object"); + } + if (!it->second.is()) { + return TResult::Error("ValueError: key \"" + key + "\" has unexpected value type."); + } + return TResult::Ok(it->second.get()); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, the default value is returned. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or the default value if the key doesn't exist or has null value + * , or the error message. + */ +template +inline Result LookupOrDefaultWithResultReturn(const picojson::object& json, + const std::string& key, + const ValueType& default_value) { + using TResult = Result; + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return TResult::Ok(default_value); + } + if (!it->second.is()) { + return TResult::Error("ValueError: key \"" + key + "\" has unexpected value type."); + } + return TResult::Ok(it->second.get()); +} +/*! + * \brief Lookup a JSON object by a key, and convert it to a given type. + * If the key doesn't exist or has null value, return std::nullopt. + * \param json The JSON object to look up. + * \param key The key to look up. + * \tparam ValueType The type to be converted to. + * \return The converted value, or std::nullopt if the value doesn't exist or has null value, + * , or the error message. + */ +template +inline Result> LookupOptionalWithResultReturn(const picojson::object& json, + const std::string& key) { + using TResult = Result>; + auto it = json.find(key); + if (it == json.end() || it->second.is()) { + return TResult::Ok(std::nullopt); + } + if (!it->second.is()) { + return TResult::Error("ValueError: key \"" + key + "\" has unexpected value type."); + } + return TResult::Ok(it->second.get()); +} + +// Implementation details /*! \brief ShapeTuple extension to incorporate symbolic shapes. */ struct SymShapeTuple { @@ -112,8 +196,6 @@ struct SymShapeTuple { } }; -// Implementation details - namespace details { inline tvm::runtime::DataType DTypeFromString(const std::string& s) { @@ -149,33 +231,6 @@ inline ValueType Lookup(const picojson::object& json, const std::string& key) { return it->second.get(); } -template -inline ValueType LookupOrDefault(const picojson::object& json, const std::string& key, - const ValueType& default_value) { - auto it = json.find(key); - if (it == json.end()) { - return default_value; - } - - if (it->second.is()) { - return default_value; - } - - CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; - return it->second.get(); -} - -template -inline std::optional LookupOptional(const picojson::object& json, - const std::string& key) { - auto it = json.find(key); - if (it == json.end() || it->second.is()) { - return std::nullopt; - } - CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; - return it->second.get(); -} - template inline ValueType Lookup(const picojson::array& json, int index) { CHECK(index < json.size()) << "IndexError: json::array index out of range"; @@ -205,17 +260,6 @@ inline SymShapeTuple Lookup(const picojson::array& json, int index) { return details::SymShapeTupleFromArray(Lookup(json, index)); } -inline picojson::object ParseToJsonObject(const std::string& json_str) { - picojson::value result; - std::string err = picojson::parse(result, json_str); - if (!err.empty()) { - LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << json_str; - } - CHECK(result.is()) - << "ValueError: The given string is not a JSON object: " << json_str; - return result.get(); -} - } // namespace json } // namespace llm } // namespace mlc