From 40f1d996ecd38f9583f59336bfdd51dfa92bfd35 Mon Sep 17 00:00:00 2001 From: James Date: Mon, 23 Sep 2024 14:56:32 +0700 Subject: [PATCH 1/6] feat: pulling interact with new model.list --- engine/commands/model_import_cmd.cc | 4 +- engine/services/model_service.cc | 14 +++--- engine/utils/cortexso_parser.h | 49 ++++++++++---------- engine/utils/model_callback_utils.h | 70 +++++++++++++++++------------ engine/utils/modellist_utils.cc | 7 +-- engine/utils/modellist_utils.h | 10 +++-- engine/utils/url_parser.h | 4 ++ 7 files changed, 89 insertions(+), 69 deletions(-) diff --git a/engine/commands/model_import_cmd.cc b/engine/commands/model_import_cmd.cc index 193b2488b..3fb047a9d 100644 --- a/engine/commands/model_import_cmd.cc +++ b/engine/commands/model_import_cmd.cc @@ -1,10 +1,8 @@ #include "model_import_cmd.h" #include -#include #include #include "config/gguf_parser.h" #include "config/yaml_config.h" -#include "trantor/utils/Logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" #include "utils/modellist_utils.h" @@ -45,7 +43,7 @@ void ModelImportCmd::Exec() { } } catch (const std::exception& e) { - // don't need to remove yml file here, because it's written only if model entry is successfully added, + // don't need to remove yml file here, because it's written only if model entry is successfully added, // remove file here can make it fail with edge case when user try to import new model with existed model_id CLI_LOG("Error importing model path '" + model_path_ + "' with model_id '" + model_handle_ + "': " + e.what()); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 29575dfab..dc6fc3f68 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -95,12 +95,16 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { url_obj.pathParams[2] = "resolve"; } } - + auto author{url_obj.pathParams[0]}; auto model_id{url_obj.pathParams[1]}; auto 
file_name{url_obj.pathParams.back()}; - auto local_path = - file_manager_utils::GetModelsContainerPath() / model_id / model_id; + if (author == "cortexso") { + return DownloadModelFromCortexso(model_id); + } + + auto local_path{file_manager_utils::GetModelsContainerPath() / + "huggingface.co" / author / model_id / file_name}; try { std::filesystem::create_directories(local_path.parent_path()); @@ -120,10 +124,10 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { .localPath = local_path, }}}}; - auto on_finished = [](const DownloadTask& finishedTask) { + auto on_finished = [&author](const DownloadTask& finishedTask) { CLI_LOG("Model " << finishedTask.id << " downloaded successfully!") auto gguf_download_item = finishedTask.items[0]; - model_callback_utils::ParseGguf(gguf_download_item); + model_callback_utils::ParseGguf(gguf_download_item, author); }; download_service_.AddDownloadTask(downloadTask, on_finished); diff --git a/engine/utils/cortexso_parser.h b/engine/utils/cortexso_parser.h index d4e85bee9..af3372022 100644 --- a/engine/utils/cortexso_parser.h +++ b/engine/utils/cortexso_parser.h @@ -1,5 +1,4 @@ #include -#include #include #include @@ -7,57 +6,57 @@ #include #include "httplib.h" #include "utils/file_manager_utils.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" namespace cortexso_parser { -constexpr static auto kHuggingFaceHost = "https://huggingface.co"; +constexpr static auto kHuggingFaceHost = "huggingface.co"; inline std::optional getDownloadTask( const std::string& modelId, const std::string& branch = "main") { using namespace nlohmann; - std::ostringstream oss; - oss << "/api/models/cortexso/" << modelId << "/tree/" << branch; - const std::string url = oss.str(); + url_parser::Url url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"api", "models", "cortexso", modelId, "tree", branch}}; - std::ostringstream repoAndModelId; - repoAndModelId << "cortexso/" << modelId; - const 
std::string repoAndModelIdStr = repoAndModelId.str(); - - httplib::Client cli(kHuggingFaceHost); - if (auto res = cli.Get(url)) { + httplib::Client cli(url.GetProtocolAndHost()); + if (auto res = cli.Get(url.GetPathAndQuery())) { if (res->status == httplib::StatusCode::OK_200) { try { auto jsonResponse = json::parse(res->body); - std::vector downloadItems{}; - std::filesystem::path model_container_path = - file_manager_utils::GetModelsContainerPath() / modelId; + std::vector download_items{}; + auto model_container_path = + file_manager_utils::GetModelsContainerPath() / "cortex.so" / + modelId / branch; file_manager_utils::CreateDirectoryRecursively( model_container_path.string()); for (const auto& [key, value] : jsonResponse.items()) { - std::ostringstream downloadUrlOutput; auto path = value["path"].get(); if (path == ".gitattributes" || path == ".gitignore" || path == "README.md") { continue; } - downloadUrlOutput << kHuggingFaceHost << "/" << repoAndModelIdStr - << "/resolve/" << branch << "/" << path; - const std::string download_url = downloadUrlOutput.str(); - auto local_path = model_container_path / path; + url_parser::Url download_url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"cortexso", modelId, "resolve", branch, path}}; - downloadItems.push_back(DownloadItem{.id = path, - .downloadUrl = download_url, - .localPath = local_path}); + auto local_path = model_container_path / path; + download_items.push_back( + DownloadItem{.id = path, + .downloadUrl = download_url.ToFullPath(), + .localPath = local_path}); } - DownloadTask downloadTask{ + DownloadTask download_tasks{ .id = branch == "main" ? 
modelId : modelId + "-" + branch, .type = DownloadType::Model, - .items = downloadItems}; + .items = download_items}; - return downloadTask; + return download_tasks; } catch (const json::parse_error& e) { CTL_ERR("JSON parse error: {}" << e.what()); } diff --git a/engine/utils/model_callback_utils.h b/engine/utils/model_callback_utils.h index 3a3b0f288..c6e98dd48 100644 --- a/engine/utils/model_callback_utils.h +++ b/engine/utils/model_callback_utils.h @@ -6,27 +6,14 @@ #include "config/gguf_parser.h" #include "config/yaml_config.h" #include "services/download_service.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace model_callback_utils { -inline void WriteYamlOutput(const DownloadItem& modelYmlDownloadItem) { - config::YamlHandler handler; - handler.ModelConfigFromFile(modelYmlDownloadItem.localPath.string()); - config::ModelConfig model_config = handler.GetModelConfig(); - model_config.id = - modelYmlDownloadItem.localPath.parent_path().filename().string(); - - CTL_INF("Updating model config in " - << modelYmlDownloadItem.localPath.string()); - handler.UpdateModelConfig(model_config); - std::string yaml_filename{model_config.id + ".yaml"}; - std::filesystem::path yaml_output = - modelYmlDownloadItem.localPath.parent_path().parent_path() / - yaml_filename; - handler.WriteYamlFile(yaml_output.string()); -} -inline void ParseGguf(const DownloadItem& ggufDownloadItem) { +inline void ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author = nullptr) { config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; gguf_handler.Parse(ggufDownloadItem.localPath.string()); @@ -36,17 +23,27 @@ inline void ParseGguf(const DownloadItem& ggufDownloadItem) { model_config.files = {ggufDownloadItem.localPath.string()}; yaml_handler.UpdateModelConfig(model_config); - std::string yaml_filename{model_config.id + ".yaml"}; - std::filesystem::path yaml_output = - 
ggufDownloadItem.localPath.parent_path().parent_path() / yaml_filename; - std::filesystem::path yaml_path(ggufDownloadItem.localPath.parent_path() / - "model.yml"); - if (!std::filesystem::exists(yaml_output)) { // if model.yml doesn't exist - yaml_handler.WriteYamlFile(yaml_output.string()); - } + auto yaml_path{ggufDownloadItem.localPath}; + auto yaml_name = yaml_path.replace_extension(".yml"); + if (!std::filesystem::exists(yaml_path)) { yaml_handler.WriteYamlFile(yaml_path.string()); } + + auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + + auto author_id = author.has_value() ? author.value() : "cortexso"; + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = model_config.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = model_config.id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); } inline void DownloadModelCb(const DownloadTask& finishedTask) { @@ -67,12 +64,27 @@ inline void DownloadModelCb(const DownloadTask& finishedTask) { } } - if (model_yml_di != nullptr) { - WriteYamlOutput(*model_yml_di); - } - if (need_parse_gguf && gguf_di != nullptr) { ParseGguf(*gguf_di); } + + if (model_yml_di != nullptr) { + auto url_obj = url_parser::FromUrlString(model_yml_di->downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_yml_di->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = mc.name, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = model_yml_di->localPath.string(), 
+ .model_alias = mc.name, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); + } } } // namespace model_callback_utils diff --git a/engine/utils/modellist_utils.cc b/engine/utils/modellist_utils.cc index 261bf58d5..7e1a43833 100644 --- a/engine/utils/modellist_utils.cc +++ b/engine/utils/modellist_utils.cc @@ -3,10 +3,10 @@ #include #include #include -#include #include #include #include "file_manager_utils.h" + namespace modellist_utils { const std::string ModelListUtils::kModelListPath = (file_manager_utils::GetModelsContainerPath() / @@ -208,7 +208,8 @@ bool ModelListUtils::UpdateModelAlias(const std::string& model_id, }); bool check_alias_unique = std::none_of( entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return (entry.model_id == new_model_alias && entry.model_id != model_id) || + return (entry.model_id == new_model_alias && + entry.model_id != model_id) || entry.model_alias == new_model_alias; }); if (it != entries.end() && check_alias_unique) { @@ -237,4 +238,4 @@ bool ModelListUtils::DeleteModelEntry(const std::string& identifier) { } return false; // Entry not found or not in READY state } -} // namespace modellist_utils \ No newline at end of file +} // namespace modellist_utils diff --git a/engine/utils/modellist_utils.h b/engine/utils/modellist_utils.h index 75a41d880..b7aaca81a 100644 --- a/engine/utils/modellist_utils.h +++ b/engine/utils/modellist_utils.h @@ -1,9 +1,10 @@ #pragma once + #include #include #include #include -#include "logging_utils.h" + namespace modellist_utils { enum class ModelStatus { READY, RUNNING }; @@ -22,7 +23,7 @@ class ModelListUtils { private: mutable std::mutex mutex_; // For thread safety - bool IsUnique(const std::vector& entries, + bool IsUnique(const std::vector& entries, const std::string& model_id, const std::string& model_alias) const; void SaveModelList(const std::vector& entries) const; @@ -40,6 +41,7 @@ class ModelListUtils { bool 
UpdateModelEntry(const std::string& identifier, const ModelEntry& updated_entry); bool DeleteModelEntry(const std::string& identifier); - bool UpdateModelAlias(const std::string& model_id, const std::string& model_alias); + bool UpdateModelAlias(const std::string& model_id, + const std::string& model_alias); }; -} // namespace modellist_utils \ No newline at end of file +} // namespace modellist_utils diff --git a/engine/utils/url_parser.h b/engine/utils/url_parser.h index 55dd557b8..b8256c92f 100644 --- a/engine/utils/url_parser.h +++ b/engine/utils/url_parser.h @@ -30,6 +30,10 @@ struct Url { } return path; }; + + std::string ToFullPath() const { + return GetProtocolAndHost() + GetPathAndQuery(); + } }; const std::regex url_regex( From 084592b5ce95eff5128f2d1b6a15909493bc4e62 Mon Sep 17 00:00:00 2001 From: James Date: Tue, 24 Sep 2024 09:20:22 +0700 Subject: [PATCH 2/6] feat: return model id when download model success --- engine/services/model_service.cc | 116 ++++++++++++++++++++++++++----- engine/services/model_service.h | 23 +++--- 2 files changed, 112 insertions(+), 27 deletions(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index dc6fc3f68..485bec869 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -2,15 +2,18 @@ #include #include #include +#include "config/gguf_parser.h" +#include "config/yaml_config.h" #include "utils/cli_selection_utils.h" #include "utils/cortexso_parser.h" #include "utils/file_manager_utils.h" #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" -#include "utils/model_callback_utils.h" +#include "utils/modellist_utils.h" #include "utils/string_utils.h" -void ModelService::DownloadModel(const std::string& input) { +std::optional ModelService::DownloadModel( + const std::string& input) { if (input.empty()) { throw std::runtime_error( "Input must be Cortex Model Hub handle or HuggingFace url!"); @@ -32,15 +35,15 @@ void 
ModelService::DownloadModel(const std::string& input) { return DownloadModelByModelName(model_name); } - DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); CLI_LOG("Model " << model_name << " downloaded successfully!") - return; + return DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); } return DownloadModelByModelName(input); } -void ModelService::DownloadModelByModelName(const std::string& modelName) { +std::optional ModelService::DownloadModelByModelName( + const std::string& modelName) { try { auto branches = huggingface_utils::GetModelRepositoryBranches("cortexso", modelName); @@ -52,12 +55,13 @@ void ModelService::DownloadModelByModelName(const std::string& modelName) { } if (options.empty()) { CLI_LOG("No variant found"); - return; + return std::nullopt; } auto selection = cli_selection_utils::PrintSelection(options); - DownloadModelFromCortexso(modelName, selection.value()); + return DownloadModelFromCortexso(modelName, selection.value()); } catch (const std::runtime_error& e) { CLI_LOG("Error downloading model, " << e.what()); + return std::nullopt; } } @@ -87,7 +91,8 @@ std::optional ModelService::GetDownloadedModel( return std::nullopt; } -void ModelService::DownloadModelByDirectUrl(const std::string& url) { +std::optional ModelService::DownloadModelByDirectUrl( + const std::string& url) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { @@ -103,6 +108,9 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { return DownloadModelFromCortexso(model_id); } + std::string huggingFaceHost{kHuggingFaceHost}; + std::string unique_model_id{huggingFaceHost + "/" + author + "/" + model_id + + "/" + file_name}; auto local_path{file_manager_utils::GetModelsContainerPath() / "huggingface.co" / author / model_id / file_name}; @@ -119,33 +127,68 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { auto downloadTask{DownloadTask{.id = model_id, .type = 
DownloadType::Model, .items = {DownloadItem{ - .id = url_obj.pathParams.back(), + .id = unique_model_id, .downloadUrl = download_url, .localPath = local_path, }}}}; - auto on_finished = [&author](const DownloadTask& finishedTask) { + auto on_finished = [&](const DownloadTask& finishedTask) { CLI_LOG("Model " << finishedTask.id << " downloaded successfully!") auto gguf_download_item = finishedTask.items[0]; - model_callback_utils::ParseGguf(gguf_download_item, author); + ParseGguf(gguf_download_item, author); }; download_service_.AddDownloadTask(downloadTask, on_finished); + return unique_model_id; } -void ModelService::DownloadModelFromCortexso(const std::string& name, - const std::string& branch) { +std::optional ModelService::DownloadModelFromCortexso( + const std::string& name, const std::string& branch) { + auto downloadTask = cortexso_parser::getDownloadTask(name, branch); if (downloadTask.has_value()) { - DownloadService().AddDownloadTask(downloadTask.value(), - model_callback_utils::DownloadModelCb); - CLI_LOG("Model " << name << " downloaded successfully!") + std::string model_id{name + ":" + branch}; + DownloadService().AddDownloadTask( + downloadTask.value(), [&](const DownloadTask& finishedTask) { + const DownloadItem* model_yml_item = nullptr; + auto need_parse_gguf = true; + + for (const auto& item : finishedTask.items) { + if (item.localPath.filename().string() == "model.yml") { + model_yml_item = &item; + } + } + + if (model_yml_item != nullptr) { + auto url_obj = + url_parser::FromUrlString(model_yml_item->downloadUrl); + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile( + model_yml_item->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = 
model_yml_item->localPath.string(), + .model_alias = model_id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); + } + }); + + CLI_LOG("Model " << model_id << " downloaded successfully!") + return model_id; } else { CTL_ERR("Model not found"); + return std::nullopt; } } -void ModelService::DownloadHuggingFaceGgufModel( +std::optional ModelService::DownloadHuggingFaceGgufModel( const std::string& author, const std::string& modelName, std::optional fileName) { auto repo_info = @@ -153,7 +196,7 @@ void ModelService::DownloadHuggingFaceGgufModel( if (!repo_info.has_value()) { // throw is better? CTL_ERR("Model not found"); - return; + return std::nullopt; } if (!repo_info->gguf.has_value()) { @@ -172,5 +215,40 @@ void ModelService::DownloadHuggingFaceGgufModel( auto download_url = huggingface_utils::GetDownloadableUrl(author, modelName, selection.value()); - DownloadModelByDirectUrl(download_url); + return DownloadModelByDirectUrl(download_url); +} + +void ModelService::ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author) const { + + config::GGUFHandler gguf_handler; + config::YamlHandler yaml_handler; + gguf_handler.Parse(ggufDownloadItem.localPath.string()); + config::ModelConfig model_config = gguf_handler.GetModelConfig(); + model_config.id = + ggufDownloadItem.localPath.parent_path().filename().string(); + model_config.files = {ggufDownloadItem.localPath.string()}; + yaml_handler.UpdateModelConfig(model_config); + + auto yaml_path{ggufDownloadItem.localPath}; + auto yaml_name = yaml_path.replace_extension(".yml"); + + if (!std::filesystem::exists(yaml_path)) { + yaml_handler.WriteYamlFile(yaml_path.string()); + } + + auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + + auto author_id = author.has_value() ? 
author.value() : "cortexso"; + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = ggufDownloadItem.id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry, true); } diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 06212aaee..4237f1b17 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -8,27 +8,34 @@ class ModelService { public: ModelService() : download_service_{DownloadService()} {}; - void DownloadModel(const std::string& input); + /** + * Return model id if download successfully + */ + std::optional DownloadModel(const std::string& input); std::optional GetDownloadedModel( const std::string& modelId) const; private: - void DownloadModelByDirectUrl(const std::string& url); + std::optional DownloadModelByDirectUrl(const std::string& url); - void DownloadModelFromCortexso(const std::string& name, - const std::string& branch = "main"); + std::optional DownloadModelFromCortexso( + const std::string& name, const std::string& branch = "main"); /** * Handle downloading model which have following pattern: author/model_name */ - void DownloadHuggingFaceGgufModel(const std::string& author, - const std::string& modelName, - std::optional fileName); + std::optional DownloadHuggingFaceGgufModel( + const std::string& author, const std::string& modelName, + std::optional fileName); - void DownloadModelByModelName(const std::string& modelName); + std::optional DownloadModelByModelName( + const std::string& modelName); DownloadService download_service_; + void ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author = nullptr) const; + constexpr auto static kHuggingFaceHost = "huggingface.co"; }; From e89a72fd70757b4e2f46a1df4ebe5616c6eaba09 Mon Sep 17 
00:00:00 2001 From: vansangpfiev Date: Tue, 24 Sep 2024 10:45:05 +0700 Subject: [PATCH 3/6] feat: model delete for new model.list (#1317) * feat: models delete for new models data folder structure * feat: delete model --- engine/commands/model_del_cmd.cc | 76 ++++++++++++++------------------ engine/commands/model_del_cmd.h | 2 +- 2 files changed, 35 insertions(+), 43 deletions(-) diff --git a/engine/commands/model_del_cmd.cc b/engine/commands/model_del_cmd.cc index f2023f5c1..7f6b6d32a 100644 --- a/engine/commands/model_del_cmd.cc +++ b/engine/commands/model_del_cmd.cc @@ -2,55 +2,47 @@ #include "cmd_info.h" #include "config/yaml_config.h" #include "utils/file_manager_utils.h" +#include "utils/modellist_utils.h" namespace commands { -bool ModelDelCmd::Exec(const std::string& model_id) { - // TODO this implentation may be changed after we have a decision - // on https://github.com/janhq/cortex.cpp/issues/1154 but the logic should be similar - CmdInfo ci(model_id); - std::string model_file = - ci.branch == "main" ? 
ci.model_name : ci.model_name + "-" + ci.branch; - auto models_path = file_manager_utils::GetModelsContainerPath(); - if (std::filesystem::exists(models_path) && - std::filesystem::is_directory(models_path)) { - // Iterate through directory - for (const auto& entry : std::filesystem::directory_iterator(models_path)) { - if (entry.is_regular_file() && entry.path().extension() == ".yaml") { - try { - config::YamlHandler handler; - handler.ModelConfigFromFile(entry.path().string()); - auto cfg = handler.GetModelConfig(); - if (entry.path().stem().string() == model_file) { - // Delete data - if (cfg.files.size() > 0) { - std::filesystem::path f(cfg.files[0]); - auto rel = std::filesystem::relative(f, models_path); - // Only delete model data if it is stored in our models folder - if (!rel.empty()) { - if (cfg.engine == "cortex.llamacpp") { - std::filesystem::remove_all(f.parent_path()); - } else { - std::filesystem::remove_all(f); - } - } - } +bool ModelDelCmd::Exec(const std::string& model_handle) { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; - // Delete yaml file - std::filesystem::remove(entry); - CLI_LOG("The model " << model_id << " was deleted"); - return true; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + // Remove yaml file + std::filesystem::remove(model_entry.path_to_model_yaml); + // Remove model files if they are not imported locally + if (model_entry.branch_name != "imported") { + if (mc.files.size() > 0) { + if (mc.engine == "cortex.llamacpp") { + for (auto& file : mc.files) { + std::filesystem::path gguf_p(file); + std::filesystem::remove(gguf_p); } - } catch (const std::exception& e) { - CTL_WRN("Error reading yaml file '" << entry.path().string() - << "': " << e.what()); - return false; + } else { + std::filesystem::path f(mc.files[0]); + 
std::filesystem::remove_all(f); } + } else { + CTL_WRN("model config files are empty!"); } } - } - - CLI_LOG("Model does not exist: " << model_id); - return false; + // update model.list + if (modellist_handler.DeleteModelEntry(model_handle)) { + CLI_LOG("The model " << model_handle << " was deleted"); + return true; + } else { + CTL_ERR("Could not delete model: " << model_handle); + return false; + } + } catch (const std::exception& e) { + CLI_LOG("Fail to delete model with ID '" + model_handle + "': " + e.what()); + false; + } } } // namespace commands \ No newline at end of file diff --git a/engine/commands/model_del_cmd.h b/engine/commands/model_del_cmd.h index 0dd41f74e..437564208 100644 --- a/engine/commands/model_del_cmd.h +++ b/engine/commands/model_del_cmd.h @@ -6,6 +6,6 @@ namespace commands { class ModelDelCmd { public: - bool Exec(const std::string& model_id); + bool Exec(const std::string& model_handle); }; } \ No newline at end of file From 7c31b711616eeea35085cbd0664f7f02adf4a7c4 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 24 Sep 2024 11:22:08 +0700 Subject: [PATCH 4/6] feat: cortex chat/run/models start with new model data structure (#1301) * feat: cortex chat/run/models start with new model data structure * temp * fix: use model_id from model service and std::get_line * f:m --- engine/commands/chat_cmd.cc | 34 ++++-- engine/commands/chat_cmd.h | 9 +- engine/commands/model_start_cmd.cc | 56 ++++++---- engine/commands/model_start_cmd.h | 9 +- engine/commands/model_status_cmd.cc | 17 +++ engine/commands/model_status_cmd.h | 2 + engine/commands/run_cmd.cc | 101 +++++++++--------- engine/commands/run_cmd.h | 6 +- engine/controllers/command_line_parser.cc | 29 ++--- engine/main.cc | 6 +- engine/services/download_service.cc | 4 +- .../test/components/test_modellist_utils.cc | 8 ++ engine/utils/cli_selection_utils.h | 2 +- engine/utils/modellist_utils.cc | 15 +++ engine/utils/modellist_utils.h | 1 + engine/utils/url_parser.h | 2 + 16 files 
changed, 183 insertions(+), 118 deletions(-) diff --git a/engine/commands/chat_cmd.cc b/engine/commands/chat_cmd.cc index da232a321..e4d0eda3d 100644 --- a/engine/commands/chat_cmd.cc +++ b/engine/commands/chat_cmd.cc @@ -6,6 +6,7 @@ #include "server_start_cmd.h" #include "trantor/utils/Logger.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace commands { namespace { @@ -36,23 +37,36 @@ struct ChunkParser { } }; -ChatCmd::ChatCmd(std::string host, int port, const config::ModelConfig& mc) - : host_(std::move(host)), port_(port), mc_(mc) {} +void ChatCmd::Exec(const std::string& host, int port, + const std::string& model_handle, std::string msg) { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + Exec(host, port, mc, std::move(msg)); + } catch (const std::exception& e) { + CLI_LOG("Fail to start model information with ID '" + model_handle + + "': " + e.what()); + } +} -void ChatCmd::Exec(std::string msg) { +void ChatCmd::Exec(const std::string& host, int port, + const config::ModelConfig& mc, std::string msg) { + auto address = host + ":" + std::to_string(port); // Check if server is started { - if (!commands::IsServerAlive(host_, port_)) { + if (!commands::IsServerAlive(host, port)) { CLI_LOG("Server is not started yet, please run `" << commands::GetCortexBinary() << " start` to start server!"); return; } } - auto address = host_ + ":" + std::to_string(port_); // Only check if llamacpp engine - if ((mc_.engine.find("llamacpp") != std::string::npos) && - !commands::ModelStatusCmd().IsLoaded(host_, port_, mc_)) { + if ((mc.engine.find("llamacpp") != std::string::npos) && + !commands::ModelStatusCmd().IsLoaded(host, port, mc)) { CLI_LOG("Model is not loaded yet!"); return; } @@ -78,12 +92,12 @@ void 
ChatCmd::Exec(std::string msg) { new_data["role"] = kUser; new_data["content"] = user_input; histories_.push_back(std::move(new_data)); - json_data["engine"] = mc_.engine; + json_data["engine"] = mc.engine; json_data["messages"] = histories_; - json_data["model"] = mc_.name; + json_data["model"] = mc.name; //TODO: support non-stream json_data["stream"] = true; - json_data["stop"] = mc_.stop; + json_data["stop"] = mc.stop; auto data_str = json_data.dump(); // std::cout << data_str << std::endl; cli.set_read_timeout(std::chrono::seconds(60)); diff --git a/engine/commands/chat_cmd.h b/engine/commands/chat_cmd.h index d5b48927c..596cfce2d 100644 --- a/engine/commands/chat_cmd.h +++ b/engine/commands/chat_cmd.h @@ -7,13 +7,12 @@ namespace commands { class ChatCmd { public: - ChatCmd(std::string host, int port, const config::ModelConfig& mc); - void Exec(std::string msg); + void Exec(const std::string& host, int port, const std::string& model_handle, + std::string msg); + void Exec(const std::string& host, int port, const config::ModelConfig& mc, + std::string msg); private: - std::string host_; - int port_; - const config::ModelConfig& mc_; std::vector histories_; }; } // namespace commands \ No newline at end of file diff --git a/engine/commands/model_start_cmd.cc b/engine/commands/model_start_cmd.cc index 1a96b4fee..1340614d9 100644 --- a/engine/commands/model_start_cmd.cc +++ b/engine/commands/model_start_cmd.cc @@ -7,43 +7,59 @@ #include "trantor/utils/Logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace commands { -ModelStartCmd::ModelStartCmd(std::string host, int port, - const config::ModelConfig& mc) - : host_(std::move(host)), port_(port), mc_(mc) {} +bool ModelStartCmd::Exec(const std::string& host, int port, + const std::string& model_handle) { -bool ModelStartCmd::Exec() { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; + try { + auto 
model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + return Exec(host, port, mc); + } catch (const std::exception& e) { + CLI_LOG("Fail to start model information with ID '" + model_handle + + "': " + e.what()); + return false; + } +} + +bool ModelStartCmd::Exec(const std::string& host, int port, + const config::ModelConfig& mc) { // Check if server is started - if (!commands::IsServerAlive(host_, port_)) { + if (!commands::IsServerAlive(host, port)) { CLI_LOG("Server is not started yet, please run `" << commands::GetCortexBinary() << " start` to start server!"); return false; } + // Only check for llamacpp for now - if ((mc_.engine.find("llamacpp") != std::string::npos) && - commands::ModelStatusCmd().IsLoaded(host_, port_, mc_)) { + if ((mc.engine.find("llamacpp") != std::string::npos) && + commands::ModelStatusCmd().IsLoaded(host, port, mc)) { CLI_LOG("Model has already been started!"); return true; } - httplib::Client cli(host_ + ":" + std::to_string(port_)); + httplib::Client cli(host + ":" + std::to_string(port)); nlohmann::json json_data; - if (mc_.files.size() > 0) { + if (mc.files.size() > 0) { // TODO(sang) support multiple files - json_data["model_path"] = mc_.files[0]; + json_data["model_path"] = mc.files[0]; } else { LOG_WARN << "model_path is empty"; return false; } - json_data["model"] = mc_.name; - json_data["system_prompt"] = mc_.system_template; - json_data["user_prompt"] = mc_.user_template; - json_data["ai_prompt"] = mc_.ai_template; - json_data["ctx_len"] = mc_.ctx_len; - json_data["stop"] = mc_.stop; - json_data["engine"] = mc_.engine; + json_data["model"] = mc.name; + json_data["system_prompt"] = mc.system_template; + json_data["user_prompt"] = mc.user_template; + json_data["ai_prompt"] = mc.ai_template; + json_data["ctx_len"] = mc.ctx_len; + json_data["stop"] = mc.stop; + json_data["engine"] = mc.engine; auto data_str = 
json_data.dump(); cli.set_read_timeout(std::chrono::seconds(60)); @@ -52,13 +68,17 @@ bool ModelStartCmd::Exec() { if (res) { if (res->status == httplib::StatusCode::OK_200) { CLI_LOG("Model loaded!"); + return true; + } else { + CTL_ERR("Model failed to load with status code: " << res->status); + return false; } } else { auto err = res.error(); CTL_ERR("HTTP error: " << httplib::to_string(err)); return false; } - return true; + return false; } }; // namespace commands diff --git a/engine/commands/model_start_cmd.h b/engine/commands/model_start_cmd.h index 26daf9d0e..fbf3c0645 100644 --- a/engine/commands/model_start_cmd.h +++ b/engine/commands/model_start_cmd.h @@ -6,13 +6,8 @@ namespace commands { class ModelStartCmd { public: - explicit ModelStartCmd(std::string host, int port, - const config::ModelConfig& mc); - bool Exec(); + bool Exec(const std::string& host, int port, const std::string& model_handle); - private: - std::string host_; - int port_; - const config::ModelConfig& mc_; + bool Exec(const std::string& host, int port, const config::ModelConfig& mc); }; } // namespace commands diff --git a/engine/commands/model_status_cmd.cc b/engine/commands/model_status_cmd.cc index f54aa9100..e6ba9bbe0 100644 --- a/engine/commands/model_status_cmd.cc +++ b/engine/commands/model_status_cmd.cc @@ -3,8 +3,25 @@ #include "httplib.h" #include "nlohmann/json.hpp" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace commands { +bool ModelStatusCmd::IsLoaded(const std::string& host, int port, + const std::string& model_handle) { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + return IsLoaded(host, port, mc); + } catch (const std::exception& e) { + CLI_LOG("Fail to get model status with ID '" + model_handle + + "': " + 
e.what()); + return false; + } +} + bool ModelStatusCmd::IsLoaded(const std::string& host, int port, const config::ModelConfig& mc) { httplib::Client cli(host + ":" + std::to_string(port)); diff --git a/engine/commands/model_status_cmd.h b/engine/commands/model_status_cmd.h index 2ef44a41d..273d73ef9 100644 --- a/engine/commands/model_status_cmd.h +++ b/engine/commands/model_status_cmd.h @@ -6,6 +6,8 @@ namespace commands { class ModelStatusCmd { public: + bool IsLoaded(const std::string& host, int port, + const std::string& model_handle); bool IsLoaded(const std::string& host, int port, const config::ModelConfig& mc); }; diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc index 16b496b0d..d17d91e9f 100644 --- a/engine/commands/run_cmd.cc +++ b/engine/commands/run_cmd.cc @@ -5,71 +5,76 @@ #include "model_start_cmd.h" #include "model_status_cmd.h" #include "server_start_cmd.h" +#include "utils/cortex_utils.h" #include "utils/file_manager_utils.h" - +#include "utils/modellist_utils.h" namespace commands { void RunCmd::Exec() { + std::optional model_id = model_handle_; + + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; auto address = host_ + ":" + std::to_string(port_); - CmdInfo ci(model_id_); - std::string model_file = - ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch; - // TODO should we clean all resource if something fails? - // Check if model existed. If not, download it - { - auto model_conf = model_service_.GetDownloadedModel(model_file + ".yaml"); - if (!model_conf.has_value()) { - model_service_.DownloadModel(model_id_); - } - } - // Check if engine existed. 
If not, download it + // Download model if it does not exist { - auto required_engine = engine_service_.GetEngineInfo(ci.engine_name); - if (!required_engine.has_value()) { - throw std::runtime_error("Engine not found: " + ci.engine_name); - } - if (required_engine.value().status == EngineService::kIncompatible) { - throw std::runtime_error("Engine " + ci.engine_name + " is incompatible"); - } - if (required_engine.value().status == EngineService::kNotInstalled) { - engine_service_.InstallEngine(ci.engine_name); + if (!modellist_handler.HasModel(model_handle_)) { + model_id = model_service_.DownloadModel(model_handle_); + if (!model_id.has_value()) { + CTL_ERR("Error: Could not get model_id from handle: " << model_handle_); + return; + } else { + CTL_INF("model_id: " << model_id.value()); + } } } - // Start server if it is not running - { - if (!commands::IsServerAlive(host_, port_)) { - CLI_LOG("Starting server ..."); - commands::ServerStartCmd ssc; - if (!ssc.Exec(host_, port_)) { - return; + try { + auto model_entry = modellist_handler.GetModelInfo(*model_id); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + + // Check if engine existed. 
If not, download it + { + auto required_engine = engine_service_.GetEngineInfo(mc.engine); + if (!required_engine.has_value()) { + throw std::runtime_error("Engine not found: " + mc.engine); + } + if (required_engine.value().status == EngineService::kIncompatible) { + throw std::runtime_error("Engine " + mc.engine + " is incompatible"); + } + if (required_engine.value().status == EngineService::kNotInstalled) { + engine_service_.InstallEngine(mc.engine); } } - } - config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile( - file_manager_utils::GetModelsContainerPath().string() + "/" + model_file + - ".yaml"); - auto mc = yaml_handler.GetModelConfig(); + // Start server if it is not running + { + if (!commands::IsServerAlive(host_, port_)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host_, port_)) { + return; + } + } + } - // Always start model if not llamacpp - // If it is llamacpp, then check model status first - { - if ((mc.engine.find("llamacpp") == std::string::npos) || - !commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) { - ModelStartCmd msc(host_, port_, mc); - if (!msc.Exec()) { - return; + // Always start model if not llamacpp + // If it is llamacpp, then check model status first + { + if ((mc.engine.find("llamacpp") == std::string::npos) || + !commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) { + if (!ModelStartCmd().Exec(host_, port_, mc)) { + return; + } } } - } - // Chat - { - ChatCmd cc(host_, port_, mc); - cc.Exec(""); + // Chat + ChatCmd().Exec(host_, port_, mc, ""); + } catch (const std::exception& e) { + CLI_LOG("Fail to run model with ID '" + model_handle_ + "': " + e.what()); } } }; // namespace commands diff --git a/engine/commands/run_cmd.h b/engine/commands/run_cmd.h index c862926a6..136800102 100644 --- a/engine/commands/run_cmd.h +++ b/engine/commands/run_cmd.h @@ -6,10 +6,10 @@ namespace commands { class RunCmd { public: - explicit RunCmd(std::string host, int port, std::string 
model_id) + explicit RunCmd(std::string host, int port, std::string model_handle) : host_{std::move(host)}, port_{port}, - model_id_{std::move(model_id)}, + model_handle_{std::move(model_handle)}, model_service_{ModelService()} {}; void Exec(); @@ -17,7 +17,7 @@ class RunCmd { private: std::string host_; int port_; - std::string model_id_; + std::string model_handle_; ModelService model_service_; EngineService engine_service_; diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index 31ace9ffd..6073cbbb3 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -131,17 +131,10 @@ void CommandLineParser::SetupCommonCommands() { CLI_LOG(chat_cmd->help()); return; } - commands::CmdInfo ci(cml_data_.model_id); - std::string model_file = - ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch; - config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile( - file_manager_utils::GetModelsContainerPath().string() + "/" + - model_file + ".yaml"); - commands::ChatCmd cc(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - yaml_handler.GetModelConfig()); - cc.Exec(cml_data_.msg); + + commands::ChatCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, + cml_data_.msg); }); } @@ -177,17 +170,9 @@ void CommandLineParser::SetupModelCommands() { CLI_LOG(model_start_cmd->help()); return; }; - commands::CmdInfo ci(cml_data_.model_id); - std::string model_file = - ci.branch == "main" ? 
ci.model_name : ci.model_name + "-" + ci.branch; - config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile( - file_manager_utils::GetModelsContainerPath().string() + "/" + - model_file + ".yaml"); - commands::ModelStartCmd msc(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - yaml_handler.GetModelConfig()); - msc.Exec(); + commands::ModelStartCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + cml_data_.model_id); }); auto stop_model_cmd = diff --git a/engine/main.cc b/engine/main.cc index e7fe9bd22..c461342c9 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -29,8 +29,8 @@ void RunServer() { auto config = file_manager_utils::GetCortexConfig(); - LOG_INFO << "Host: " << config.apiServerHost - << " Port: " << config.apiServerPort << "\n"; + std::cout << "Host: " << config.apiServerHost + << " Port: " << config.apiServerPort << "\n"; // Create logs/ folder and setup log to file std::filesystem::create_directories( @@ -46,6 +46,8 @@ void RunServer() { asyncFileLogger.output_(msg, len); }, [&]() { asyncFileLogger.flush(); }); + LOG_INFO << "Host: " << config.apiServerHost + << " Port: " << config.apiServerPort << "\n"; // Number of cortex.cpp threads // if (argc > 1) { // thread_num = std::atoi(argv[1]); diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 1cf8b68c4..496d01116 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -113,7 +113,7 @@ void DownloadService::Download(const std::string& download_id, << " need to be downloaded."); std::cout << "Continue download [Y/n]: " << std::flush; std::string answer{""}; - std::cin >> answer; + std::getline(std::cin, answer); if (answer == "Y" || answer == "y" || answer.empty()) { mode = "ab"; CLI_LOG("Resuming download.."); @@ -126,7 +126,7 @@ void DownloadService::Download(const std::string& download_id, std::cout << "Re-download? 
[Y/n]: " << std::flush; std::string answer = ""; - std::cin >> answer; + std::getline(std::cin, answer); if (answer == "Y" || answer == "y" || answer.empty()) { CLI_LOG("Re-downloading.."); } else { diff --git a/engine/test/components/test_modellist_utils.cc b/engine/test/components/test_modellist_utils.cc index 2a7abc05a..68b06483d 100644 --- a/engine/test/components/test_modellist_utils.cc +++ b/engine/test/components/test_modellist_utils.cc @@ -120,4 +120,12 @@ TEST_F(ModelListUtilsTestSuite, TestUpdateModelAlias) { // Clean up model_list_.DeleteModelEntry("test_model_id"); model_list_.DeleteModelEntry("another_model_id"); +} + +TEST_F(ModelListUtilsTestSuite, TestHasModel) { + model_list_.AddModelEntry(kTestModel); + + EXPECT_TRUE(model_list_.HasModel("test_model_id")); + EXPECT_TRUE(model_list_.HasModel("test_alias")); + EXPECT_FALSE(model_list_.HasModel("non_existent_model")); } \ No newline at end of file diff --git a/engine/utils/cli_selection_utils.h b/engine/utils/cli_selection_utils.h index d3848c5bb..0c2453478 100644 --- a/engine/utils/cli_selection_utils.h +++ b/engine/utils/cli_selection_utils.h @@ -20,7 +20,7 @@ inline std::optional PrintSelection( std::string selection{""}; PrintMenu(options); std::cout << "Select an option (" << 1 << "-" << options.size() << "): "; - std::cin >> selection; + std::getline(std::cin, selection); if (selection.empty()) { return std::nullopt; diff --git a/engine/utils/modellist_utils.cc b/engine/utils/modellist_utils.cc index 7e1a43833..d577519f3 100644 --- a/engine/utils/modellist_utils.cc +++ b/engine/utils/modellist_utils.cc @@ -238,4 +238,19 @@ bool ModelListUtils::DeleteModelEntry(const std::string& identifier) { } return false; // Entry not found or not in READY state } + +bool ModelListUtils::HasModel(const std::string& identifier) const { + std::lock_guard lock(mutex_); + auto entries = LoadModelList(); + auto it = std::find_if( + entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) { + return 
entry.model_id == identifier || entry.model_alias == identifier; + }); + + if (it != entries.end()) { + return true; + } else { + return false; + } +} } // namespace modellist_utils diff --git a/engine/utils/modellist_utils.h b/engine/utils/modellist_utils.h index b7aaca81a..113591f25 100644 --- a/engine/utils/modellist_utils.h +++ b/engine/utils/modellist_utils.h @@ -43,5 +43,6 @@ class ModelListUtils { bool DeleteModelEntry(const std::string& identifier); bool UpdateModelAlias(const std::string& model_id, const std::string& model_alias); + bool HasModel(const std::string& identifier) const; }; } // namespace modellist_utils diff --git a/engine/utils/url_parser.h b/engine/utils/url_parser.h index 97d499a97..90b62143e 100644 --- a/engine/utils/url_parser.h +++ b/engine/utils/url_parser.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include From fb55754ea38a833401ab49ec9bdedac338bb0f9d Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:28:39 +0700 Subject: [PATCH 5/6] Model update command/api (#1309) * Model update command/api * fix: resume download failed Signed-off-by: James * fix: align github syntax for cuda (#1316) * fix: require sudo for cortex update (#1318) * fix: require sudo for cortex update * fix: comment * refactor code * Format code * Add clean up when finish test * remove model.list after finish test * Fix windows CI build --------- Signed-off-by: James Co-authored-by: James Co-authored-by: vansangpfiev --- engine/commands/model_upd_cmd.cc | 127 ++++++++++++++++++ engine/commands/model_upd_cmd.h | 30 +++++ engine/config/model_config.h | 108 +++++++++++++++ engine/controllers/command_line_parser.cc | 79 ++++++++++- engine/controllers/command_line_parser.h | 14 +- engine/controllers/models.cc | 44 +++++- engine/controllers/models.h | 8 +- engine/services/download_service.cc | 60 +++++++-- engine/services/download_service.h | 3 + engine/services/engine_service.cc | 9 +- 
.../test/components/test_modellist_utils.cc | 3 + 11 files changed, 460 insertions(+), 25 deletions(-) create mode 100644 engine/commands/model_upd_cmd.cc create mode 100644 engine/commands/model_upd_cmd.h diff --git a/engine/commands/model_upd_cmd.cc b/engine/commands/model_upd_cmd.cc new file mode 100644 index 000000000..eb7edd3df --- /dev/null +++ b/engine/commands/model_upd_cmd.cc @@ -0,0 +1,127 @@ +#include "model_upd_cmd.h" + +#include "utils/logging_utils.h" + +namespace commands { + +ModelUpdCmd::ModelUpdCmd(std::string model_handle) + : model_handle_(std::move(model_handle)) {} + +void ModelUpdCmd::Exec( + const std::unordered_map& options) { + try { + auto model_entry = model_list_utils_.GetModelInfo(model_handle_); + yaml_handler_.ModelConfigFromFile(model_entry.path_to_model_yaml); + model_config_ = yaml_handler_.GetModelConfig(); + + for (const auto& [key, value] : options) { + if (!value.empty()) { + UpdateConfig(key, value); + } + } + + yaml_handler_.UpdateModelConfig(model_config_); + yaml_handler_.WriteYamlFile(model_entry.path_to_model_yaml); + CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!"); + } catch (const std::exception& e) { + CLI_LOG("Failed to update model with model ID '" + model_handle_ + + "': " + e.what()); + } +} + +void ModelUpdCmd::UpdateConfig(const std::string& key, + const std::string& value) { + static const std::unordered_map< + std::string, + std::function> + updaters = { + {"name", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.name = v; + }}, + {"model", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.model = v; + }}, + {"version", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.version = v; + }}, + {"stop", &ModelUpdCmd::UpdateVectorField}, + {"top_p", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { 
self->model_config_.top_p = f; }); + }}, + {"temperature", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.temperature = f; + }); + }}, + {"frequency_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.frequency_penalty = f; + }); + }}, + {"presence_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.presence_penalty = f; + }); + }}, + {"max_tokens", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.max_tokens = static_cast(f); + }); + }}, + {"stream", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.stream = b; }); + }}, + // Add more fields here... 
+ }; + + if (auto it = updaters.find(key); it != updaters.end()) { + it->second(this, key, value); + LogUpdate(key, value); + } +} + +void ModelUpdCmd::UpdateVectorField(const std::string& key, + const std::string& value) { + std::vector tokens; + std::istringstream iss(value); + std::string token; + while (std::getline(iss, token, ',')) { + tokens.push_back(token); + } + model_config_.stop = tokens; +} + +void ModelUpdCmd::UpdateNumericField(const std::string& key, + const std::string& value, + std::function setter) { + try { + float numericValue = std::stof(value); + setter(numericValue); + } catch (const std::exception& e) { + CLI_LOG("Failed to parse numeric value for " << key << ": " << e.what()); + } +} + +void ModelUpdCmd::UpdateBooleanField(const std::string& key, + const std::string& value, + std::function setter) { + bool boolValue = (value == "true" || value == "1"); + setter(boolValue); +} + +void ModelUpdCmd::LogUpdate(const std::string& key, const std::string& value) { + CLI_LOG("Updated " << key << " to: " << value); +} + +} // namespace commands \ No newline at end of file diff --git a/engine/commands/model_upd_cmd.h b/engine/commands/model_upd_cmd.h new file mode 100644 index 000000000..51f5a88d3 --- /dev/null +++ b/engine/commands/model_upd_cmd.h @@ -0,0 +1,30 @@ +#pragma once +#include +#include +#include +#include +#include +#include "config/model_config.h" +#include "utils/modellist_utils.h" +#include "config/yaml_config.h" +namespace commands { +class ModelUpdCmd { + public: + ModelUpdCmd(std::string model_handle); + void Exec(const std::unordered_map& options); + + private: + std::string model_handle_; + config::ModelConfig model_config_; + config::YamlHandler yaml_handler_; + modellist_utils::ModelListUtils model_list_utils_; + + void UpdateConfig(const std::string& key, const std::string& value); + void UpdateVectorField(const std::string& key, const std::string& value); + void UpdateNumericField(const std::string& key, const std::string& 
value, + std::function setter); + void UpdateBooleanField(const std::string& key, const std::string& value, + std::function setter); + void LogUpdate(const std::string& key, const std::string& value); +}; +} // namespace commands \ No newline at end of file diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 74410db52..a65114ca7 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -58,7 +58,115 @@ struct ModelConfig { int n_probs = 0; int min_keep = 0; std::string grammar; + + void FromJson(const Json::Value& json) { + // do not allow updating the ID and model fields because they are unique identifiers + // if (json.isMember("id")) + // id = json["id"].asString(); + if (json.isMember("name")) + name = json["name"].asString(); + // if (json.isMember("model")) + // model = json["model"].asString(); + if (json.isMember("version")) + version = json["version"].asString(); + if (json.isMember("stop") && json["stop"].isArray()) { + stop.clear(); + for (const auto& s : json["stop"]) { + stop.push_back(s.asString()); + } + } + + if (json.isMember("stream")) + stream = json["stream"].asBool(); + if (json.isMember("top_p")) + top_p = json["top_p"].asFloat(); + if (json.isMember("temperature")) + temperature = json["temperature"].asFloat(); + if (json.isMember("frequency_penalty")) + frequency_penalty = json["frequency_penalty"].asFloat(); + if (json.isMember("presence_penalty")) + presence_penalty = json["presence_penalty"].asFloat(); + if (json.isMember("max_tokens")) + max_tokens = json["max_tokens"].asInt(); + if (json.isMember("seed")) + seed = json["seed"].asInt(); + if (json.isMember("dynatemp_range")) + dynatemp_range = json["dynatemp_range"].asFloat(); + if (json.isMember("dynatemp_exponent")) + dynatemp_exponent = json["dynatemp_exponent"].asFloat(); + if (json.isMember("top_k")) + top_k = json["top_k"].asInt(); + if (json.isMember("min_p")) + min_p = json["min_p"].asFloat(); + if (json.isMember("tfs_z")) + tfs_z = 
json["tfs_z"].asFloat(); + if (json.isMember("typ_p")) + typ_p = json["typ_p"].asFloat(); + if (json.isMember("repeat_last_n")) + repeat_last_n = json["repeat_last_n"].asInt(); + if (json.isMember("repeat_penalty")) + repeat_penalty = json["repeat_penalty"].asFloat(); + if (json.isMember("mirostat")) + mirostat = json["mirostat"].asBool(); + if (json.isMember("mirostat_tau")) + mirostat_tau = json["mirostat_tau"].asFloat(); + if (json.isMember("mirostat_eta")) + mirostat_eta = json["mirostat_eta"].asFloat(); + if (json.isMember("penalize_nl")) + penalize_nl = json["penalize_nl"].asBool(); + if (json.isMember("ignore_eos")) + ignore_eos = json["ignore_eos"].asBool(); + if (json.isMember("n_probs")) + n_probs = json["n_probs"].asInt(); + if (json.isMember("min_keep")) + min_keep = json["min_keep"].asInt(); + if (json.isMember("ngl")) + ngl = json["ngl"].asInt(); + if (json.isMember("ctx_len")) + ctx_len = json["ctx_len"].asInt(); + if (json.isMember("engine")) + engine = json["engine"].asString(); + if (json.isMember("prompt_template")) + prompt_template = json["prompt_template"].asString(); + if (json.isMember("system_template")) + system_template = json["system_template"].asString(); + if (json.isMember("user_template")) + user_template = json["user_template"].asString(); + if (json.isMember("ai_template")) + ai_template = json["ai_template"].asString(); + if (json.isMember("os")) + os = json["os"].asString(); + if (json.isMember("gpu_arch")) + gpu_arch = json["gpu_arch"].asString(); + if (json.isMember("quantization_method")) + quantization_method = json["quantization_method"].asString(); + if (json.isMember("precision")) + precision = json["precision"].asString(); + + if (json.isMember("files") && json["files"].isArray()) { + files.clear(); + for (const auto& file : json["files"]) { + files.push_back(file.asString()); + } + } + + if (json.isMember("created")) + created = json["created"].asUInt64(); + if (json.isMember("object")) + object = 
json["object"].asString(); + if (json.isMember("owned_by")) + owned_by = json["owned_by"].asString(); + if (json.isMember("text_model")) + text_model = json["text_model"].asBool(); + + if (engine == "cortex.tensorrt-llm") { + if (json.isMember("trtllm_version")) + trtllm_version = json["trtllm_version"].asString(); + if (json.isMember("tp")) + tp = json["tp"].asInt(); + } + } Json::Value ToJson() const { Json::Value obj; diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index 6073cbbb3..74155a316 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -14,6 +14,7 @@ #include "commands/model_pull_cmd.h" #include "commands/model_start_cmd.h" #include "commands/model_stop_cmd.h" +#include "commands/model_upd_cmd.h" #include "commands/run_cmd.h" #include "commands/server_start_cmd.h" #include "commands/server_stop_cmd.h" @@ -256,10 +257,8 @@ void CommandLineParser::SetupModelCommands() { commands::ModelAliasCmd mdc; mdc.Exec(cml_data_.model_id, cml_data_.model_alias); }); - - auto model_update_cmd = - models_cmd->add_subcommand("update", "Update configuration of a model"); - model_update_cmd->group(kSubcommands); + // Model update parameters comment + ModelUpdate(models_cmd); std::string model_path; auto model_import_cmd = models_cmd->add_subcommand( @@ -373,6 +372,12 @@ void CommandLineParser::SetupSystemCommands() { update_cmd->group(kSystemGroup); update_cmd->add_option("-v", cml_data_.cortex_version, ""); update_cmd->callback([this] { +#if !defined(_WIN32) + if (getuid()) { + CLI_LOG("Error: Not root user. 
Please run with sudo."); + return; + } +#endif commands::CortexUpdCmd cuc; cuc.Exec(cml_data_.cortex_version); cml_data_.check_upd = false; @@ -442,3 +447,69 @@ void CommandLineParser::EngineGet(CLI::App* parent) { [engine_name] { commands::EngineGetCmd().Exec(engine_name); }); } } + +void CommandLineParser::ModelUpdate(CLI::App* parent) { + auto model_update_cmd = + parent->add_subcommand("update", "Update configuration of a model"); + model_update_cmd->group(kSubcommands); + model_update_cmd->add_option("--model_id", cml_data_.model_id, "Model ID") + ->required(); + + // Add options dynamically + std::vector option_names = {"name", + "model", + "version", + "stop", + "top_p", + "temperature", + "frequency_penalty", + "presence_penalty", + "max_tokens", + "stream", + "ngl", + "ctx_len", + "engine", + "prompt_template", + "system_template", + "user_template", + "ai_template", + "os", + "gpu_arch", + "quantization_method", + "precision", + "tp", + "trtllm_version", + "text_model", + "files", + "created", + "object", + "owned_by", + "seed", + "dynatemp_range", + "dynatemp_exponent", + "top_k", + "min_p", + "tfs_z", + "typ_p", + "repeat_last_n", + "repeat_penalty", + "mirostat", + "mirostat_tau", + "mirostat_eta", + "penalize_nl", + "ignore_eos", + "n_probs", + "min_keep", + "grammar"}; + + for (const auto& option_name : option_names) { + model_update_cmd->add_option("--" + option_name, + cml_data_.model_update_options[option_name], + option_name); + } + + model_update_cmd->callback([this]() { + commands::ModelUpdCmd command(cml_data_.model_id); + command.Exec(cml_data_.model_update_options); + }); +} \ No newline at end of file diff --git a/engine/controllers/command_line_parser.h b/engine/controllers/command_line_parser.h index 87a8063fd..aaa24e064 100644 --- a/engine/controllers/command_line_parser.h +++ b/engine/controllers/command_line_parser.h @@ -1,9 +1,9 @@ #pragma once #include "CLI/CLI.hpp" +#include "commands/model_upd_cmd.h" #include 
"services/engine_service.h" #include "utils/config_yaml_utils.h" - class CommandLineParser { public: CommandLineParser(); @@ -11,13 +11,13 @@ class CommandLineParser { private: void SetupCommonCommands(); - + void SetupInferenceCommands(); - + void SetupModelCommands(); - + void SetupEngineCommands(); - + void SetupSystemCommands(); void EngineInstall(CLI::App* parent, const std::string& engine_name, @@ -26,10 +26,11 @@ class CommandLineParser { void EngineUninstall(CLI::App* parent, const std::string& engine_name); void EngineGet(CLI::App* parent); + void ModelUpdate(CLI::App* parent); CLI::App app_; EngineService engine_service_; - struct CmlData{ + struct CmlData { std::string model_id; std::string msg; std::string model_alias; @@ -40,6 +41,7 @@ class CommandLineParser { bool check_upd = true; int port; config_yaml_utils::CortexConfig config; + std::unordered_map model_update_options; }; CmlData cml_data_; }; diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index e857d89da..4660b50e5 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -114,7 +114,7 @@ void Models::GetModel( auto model_config = yaml_handler.GetModelConfig(); Json::Value obj = model_config.ToJson(); - + data.append(std::move(obj)); ret["data"] = data; ret["result"] = "OK"; @@ -155,7 +155,49 @@ void Models::DeleteModel(const HttpRequestPtr& req, callback(resp); } } +void Models::UpdateModel( + const HttpRequestPtr& req, + std::function&& callback) const { + if (!http_util::HasFieldInReq(req, callback, "modelId")) { + return; + } + auto model_id = (*(req->getJsonObject())).get("modelId", "").asString(); + auto json_body = *(req->getJsonObject()); + try { + modellist_utils::ModelListUtils model_list_utils; + auto model_entry = model_list_utils.GetModelInfo(model_id); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + config::ModelConfig model_config = yaml_handler.GetModelConfig(); + 
model_config.FromJson(json_body); + yaml_handler.UpdateModelConfig(model_config); + yaml_handler.WriteYamlFile(model_entry.path_to_model_yaml); + std::string message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + LOG_INFO << message; + Json::Value ret; + ret["result"] = "Updated successfully!"; + ret["modelHandle"] = model_id; + ret["message"] = message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + + } catch (const std::exception& e) { + std::string error_message = + "Error updating with model_id '" + model_id + "': " + e.what(); + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Updated failed!"; + ret["modelHandle"] = model_id; + ret["message"] = error_message; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } +} void Models::ImportModel( const HttpRequestPtr& req, std::function&& callback) const { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 4ae1ff41f..8d652c86a 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -15,6 +15,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::PullModel, "/pull", Post); METHOD_ADD(Models::ListModel, "/list", Get); METHOD_ADD(Models::GetModel, "/get", Post); + METHOD_ADD(Models::UpdateModel, "/update/", Post); METHOD_ADD(Models::ImportModel, "/import", Post); METHOD_ADD(Models::DeleteModel, "/{1}", Delete); METHOD_ADD(Models::SetModelAlias, "/alias", Post); @@ -26,8 +27,11 @@ class Models : public drogon::HttpController { std::function&& callback) const; void GetModel(const HttpRequestPtr& req, std::function&& callback) const; - void ImportModel(const HttpRequestPtr& req, - std::function&& callback) const; + void UpdateModel(const HttpRequestPtr& req, + std::function&& callback) const; + void ImportModel( + const HttpRequestPtr& req, + 
std::function&& callback) const; void DeleteModel(const HttpRequestPtr& req, std::function&& callback, const std::string& model_id) const; diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 496d01116..e3754fa76 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -12,6 +12,14 @@ #include "utils/format_utils.h" #include "utils/logging_utils.h" +#ifdef _WIN32 +#define ftell64(f) _ftelli64(f) +#define fseek64(f, o, w) _fseeki64(f, o, w) +#else +#define ftell64(f) ftello(f) +#define fseek64(f, o, w) fseeko(f, o, w) +#endif + namespace { size_t WriteCallback(void* ptr, size_t size, size_t nmemb, FILE* stream) { size_t written = fwrite(ptr, size, nmemb, stream); @@ -37,12 +45,19 @@ void DownloadService::AddDownloadTask( } // all items are valid, start downloading + bool download_successfully = true; for (const auto& item : task.items) { CLI_LOG("Start downloading: " + item.localPath.filename().string()); - Download(task.id, item, true); + try { + Download(task.id, item, true); + } catch (const std::runtime_error& e) { + CTL_ERR("Failed to download: " << item.downloadUrl << " - " << e.what()); + download_successfully = false; + break; + } } - if (callback.has_value()) { + if (download_successfully && callback.has_value()) { callback.value()(task); } } @@ -102,10 +117,15 @@ void DownloadService::Download(const std::string& download_id, std::string mode = "wb"; if (allow_resume && std::filesystem::exists(download_item.localPath) && download_item.bytes.has_value()) { - FILE* existing_file = fopen(download_item.localPath.string().c_str(), "r"); - fseek(existing_file, 0, SEEK_END); - curl_off_t existing_file_size = ftell(existing_file); - fclose(existing_file); + curl_off_t existing_file_size = GetLocalFileSize(download_item.localPath); + if (existing_file_size == -1) { + CLI_LOG("Cannot get file size: " << download_item.localPath.string() + << " . 
Start download over!"); + return; + } + CTL_INF("Existing file size: " << download_item.downloadUrl << " - " + << download_item.localPath.string() << " - " + << existing_file_size); auto missing_bytes = download_item.bytes.value() - existing_file_size; if (missing_bytes > 0) { CLI_LOG("Found unfinished download! Additional " @@ -149,9 +169,12 @@ void DownloadService::Download(const std::string& download_id, curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); if (mode == "ab") { - fseek(file, 0, SEEK_END); - curl_off_t local_file_size = ftell(file); - curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, local_file_size); + auto local_file_size = GetLocalFileSize(download_item.localPath); + if (local_file_size != -1) { + curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, local_file_size); + } else { + CTL_ERR("Cannot get file size: " << download_item.localPath.string()); + } } res = curl_easy_perform(curl); @@ -159,8 +182,31 @@ void DownloadService::Download(const std::string& download_id, if (res != CURLE_OK) { fprintf(stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); + // Release handles here: the cleanup below is skipped when throwing. + fclose(file); + curl_easy_cleanup(curl); + throw std::runtime_error("Failed to download file " + + download_item.localPath.filename().string()); } fclose(file); curl_easy_cleanup(curl); } + +// Returns the size of `path` in bytes, or -1 if it cannot be determined. +curl_off_t DownloadService::GetLocalFileSize( + const std::filesystem::path& path) const { + FILE* file = fopen(path.string().c_str(), "rb"); + if (!file) { + return -1; + } + + if (fseek64(file, 0, SEEK_END) != 0) { + fclose(file); + return -1; + } + + curl_off_t file_size = ftell64(file); + fclose(file); + return file_size; +} \ No newline at end of file diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 7063be74c..b9f93ee82 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -73,4 +74,6 @@ class DownloadService { private: void Download(const std::string& download_id, const
DownloadItem& download_item, bool allow_resume); + + curl_off_t GetLocalFileSize(const std::filesystem::path& path) const; }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 1b1f1d278..289bebd68 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -119,9 +119,12 @@ void EngineService::UnzipEngine(const std::string& engine, CTL_INF("engine: " << engine); CTL_INF("CUDA version: " << hw_inf_.cuda_driver_version); std::string cuda_variant = "cuda-"; - cuda_variant += GetSuitableCudaVersion(engine, hw_inf_.cuda_driver_version) + - "-" + hw_inf_.sys_inf->os + "-" + hw_inf_.sys_inf->arch + - ".tar.gz"; + auto cuda_github = + GetSuitableCudaVersion(engine, hw_inf_.cuda_driver_version); + // Github release cuda example: cuda-12-0-windows-amd64.tar.gz + std::replace(cuda_github.begin(), cuda_github.end(), '.', '-'); + cuda_variant += cuda_github + "-" + hw_inf_.sys_inf->os + "-" + + hw_inf_.sys_inf->arch + ".tar.gz"; CTL_INF("cuda_variant: " << cuda_variant); std::vector variants; diff --git a/engine/test/components/test_modellist_utils.cc b/engine/test/components/test_modellist_utils.cc index 68b06483d..d1dbf91e3 100644 --- a/engine/test/components/test_modellist_utils.cc +++ b/engine/test/components/test_modellist_utils.cc @@ -19,6 +19,7 @@ class ModelListUtilsTestSuite : public ::testing::Test { void TearDown() { // Clean up the temporary directory + std::remove((file_manager_utils::GetModelsContainerPath() / "model.list").string().c_str()); } TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel)); @@ -128,4 +129,6 @@ TEST_F(ModelListUtilsTestSuite, TestHasModel) { EXPECT_TRUE(model_list_.HasModel("test_model_id")); EXPECT_TRUE(model_list_.HasModel("test_alias")); EXPECT_FALSE(model_list_.HasModel("non_existent_model")); + // Clean up + model_list_.DeleteModelEntry("test_model_id"); } \ No newline at end of file From 
a4109dd930d03c50e1fced243c216087bf603076 Mon Sep 17 00:00:00 2001 From: Thuandz Date: Tue, 24 Sep 2024 13:21:07 +0700 Subject: [PATCH 6/6] Add more fields to handle when update --- engine/commands/model_upd_cmd.cc | 175 ++++++++++++++++++++++++++++++- 1 file changed, 174 insertions(+), 1 deletion(-) diff --git a/engine/commands/model_upd_cmd.cc b/engine/commands/model_upd_cmd.cc index eb7edd3df..65883def3 100644 --- a/engine/commands/model_upd_cmd.cc +++ b/engine/commands/model_upd_cmd.cc @@ -47,7 +47,60 @@ void ModelUpdCmd::UpdateConfig(const std::string& key, [](ModelUpdCmd* self, const std::string&, const std::string& v) { self->model_config_.version = v; }}, + {"engine", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.engine = v; + }}, + {"prompt_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.prompt_template = v; + }}, + {"system_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.system_template = v; + }}, + {"user_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.user_template = v; + }}, + {"ai_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.ai_template = v; + }}, + {"os", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.os = v; + }}, + {"gpu_arch", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.gpu_arch = v; + }}, + {"quantization_method", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.quantization_method = v; + }}, + {"precision", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.precision = v; + }}, + {"trtllm_version", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.trtllm_version = v; + }}, + {"object", + 
[](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.object = v; + }}, + {"owned_by", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.owned_by = v; + }}, + {"grammar", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.grammar = v; + }}, {"stop", &ModelUpdCmd::UpdateVectorField}, + {"files", &ModelUpdCmd::UpdateVectorField}, {"top_p", [](ModelUpdCmd* self, const std::string& k, const std::string& v) { self->UpdateNumericField( @@ -71,23 +124,143 @@ void ModelUpdCmd::UpdateConfig(const std::string& key, self->model_config_.presence_penalty = f; }); }}, + {"dynatemp_range", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.dynatemp_range = f; + }); + }}, + {"dynatemp_exponent", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.dynatemp_exponent = f; + }); + }}, + {"min_p", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.min_p = f; }); + }}, + {"tfs_z", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.tfs_z = f; }); + }}, + {"typ_p", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.typ_p = f; }); + }}, + {"repeat_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.repeat_penalty = f; + }); + }}, + {"mirostat_tau", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.mirostat_tau = f; + }); + }}, + 
{"mirostat_eta", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.mirostat_eta = f; + }); + }}, {"max_tokens", [](ModelUpdCmd* self, const std::string& k, const std::string& v) { self->UpdateNumericField(k, v, [self](float f) { self->model_config_.max_tokens = static_cast(f); }); }}, + {"ngl", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.ngl = static_cast(f); + }); + }}, + {"ctx_len", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.ctx_len = static_cast(f); + }); + }}, + {"tp", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.tp = static_cast(f); + }); + }}, + {"seed", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.seed = static_cast(f); + }); + }}, + {"top_k", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.top_k = static_cast(f); + }); + }}, + {"repeat_last_n", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.repeat_last_n = static_cast(f); + }); + }}, + {"n_probs", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.n_probs = static_cast(f); + }); + }}, + {"min_keep", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.min_keep = static_cast(f); + }); + }}, {"stream", [](ModelUpdCmd* self, const std::string& k, const std::string& 
v) { self->UpdateBooleanField( k, v, [self](bool b) { self->model_config_.stream = b; }); }}, - // Add more fields here... + {"text_model", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.text_model = b; }); + }}, + {"mirostat", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.mirostat = b; }); + }}, + {"penalize_nl", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.penalize_nl = b; }); + }}, + {"ignore_eos", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.ignore_eos = b; }); + }}, + {"created", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.created = static_cast(f); + }); + }}, }; if (auto it = updaters.find(key); it != updaters.end()) { it->second(this, key, value); LogUpdate(key, value); + } else { + CLI_LOG("Warning: Unknown configuration key '" << key << "' ignored."); } }