Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ void Models::PullModel(const HttpRequestPtr& req,
model_handle, desired_model_id, desired_model_name);
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
if (model_and_branch.size() == 3) {
auto mh = url_parser::Url{
.protocol = "https",
.host = kHuggingFaceHost,
.pathParams = {
model_and_branch[0],
model_and_branch[1],
"resolve",
"main",
model_and_branch[2],
}}.ToFullPath();
return model_service_->HandleDownloadUrlAsync(mh, desired_model_id,
desired_model_name);
}
return model_service_->DownloadModelFromCortexsoAsync(
model_and_branch[0], model_and_branch[1], desired_model_id);
}
Expand Down Expand Up @@ -813,15 +827,34 @@ void Models::GetModelSources(
resp->setStatusCode(k400BadRequest);
callback(resp);
} else {
auto const& info = res.value();
auto& info = res.value();
Json::Value ret;
Json::Value data(Json::arrayValue);
for (auto const& i : info) {
data.append(i);
for (auto& i : info) {
data.append(i.second.ToJson());
}
ret["data"] = data;
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);
}
}

void Models::GetModelSource(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& src) {
auto res = model_src_svc_->GetModelSource(src);
if (res.has_error()) {
Json::Value ret;
ret["message"] = res.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
callback(resp);
} else {
auto& info = res.value();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(info.ToJson());
resp->setStatusCode(k200OK);
callback(resp);
}
}
5 changes: 5 additions & 0 deletions engine/controllers/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Models : public drogon::HttpController<Models, false> {
ADD_METHOD_TO(Models::AddModelSource, "/v1/models/sources", Post);
ADD_METHOD_TO(Models::DeleteModelSource, "/v1/models/sources", Delete);
ADD_METHOD_TO(Models::GetModelSources, "/v1/models/sources", Get);
ADD_METHOD_TO(Models::GetModelSource, "/v1/models/sources/{src}", Get);
METHOD_LIST_END

explicit Models(std::shared_ptr<DatabaseService> db_service,
Expand Down Expand Up @@ -106,6 +107,10 @@ class Models : public drogon::HttpController<Models, false> {
void GetModelSources(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback);

void GetModelSource(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& src);

private:
std::shared_ptr<DatabaseService> db_service_;
std::shared_ptr<ModelService> model_service_;
Expand Down
62 changes: 45 additions & 17 deletions engine/database/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,35 +270,63 @@ bool Models::HasModel(const std::string& identifier) const {
}
}

cpp::result<std::vector<std::string>, std::string> Models::GetModelSources()
const {
cpp::result<std::vector<ModelEntry>, std::string> Models::GetModels(
const std::string& model_src) const {
try {
std::vector<std::string> sources;
std::vector<ModelEntry> res;
SQLite::Statement query(db_,
"SELECT DISTINCT model_source FROM models WHERE "
"status = \"downloadable\"");

"SELECT model_id, author_repo_id, branch_name, "
"path_to_model_yaml, model_alias, model_format, "
"model_source, status, engine, metadata FROM "
"models WHERE model_source = "
"? AND status = \"downloadable\"");
query.bind(1, model_src);
while (query.executeStep()) {
sources.push_back(query.getColumn(0).getString());
ModelEntry entry;
entry.model = query.getColumn(0).getString();
entry.author_repo_id = query.getColumn(1).getString();
entry.branch_name = query.getColumn(2).getString();
entry.path_to_model_yaml = query.getColumn(3).getString();
entry.model_alias = query.getColumn(4).getString();
entry.model_format = query.getColumn(5).getString();
entry.model_source = query.getColumn(6).getString();
entry.status = StringToStatus(query.getColumn(7).getString());
entry.engine = query.getColumn(8).getString();
entry.metadata = query.getColumn(9).getString();
res.push_back(entry);
}
return sources;
return res;
} catch (const std::exception& e) {
return cpp::fail(e.what());
}
}

cpp::result<std::vector<std::string>, std::string> Models::GetModels(
const std::string& model_src) const {
cpp::result<std::vector<ModelEntry>, std::string> Models::GetModelSources()
const {
try {
std::vector<std::string> ids;
SQLite::Statement query(db_,
"SELECT model_id FROM models WHERE model_source = "
"? AND status = \"downloadable\"");
query.bind(1, model_src);
std::vector<ModelEntry> res;
SQLite::Statement query(
db_,
"SELECT model_id, author_repo_id, branch_name, "
"path_to_model_yaml, model_alias, model_format, "
"model_source, status, engine, metadata FROM models "
"WHERE model_source != \"\" AND (status = \"downloaded\" OR status = "
"\"downloadable\")");
while (query.executeStep()) {
ids.push_back(query.getColumn(0).getString());
ModelEntry entry;
entry.model = query.getColumn(0).getString();
entry.author_repo_id = query.getColumn(1).getString();
entry.branch_name = query.getColumn(2).getString();
entry.path_to_model_yaml = query.getColumn(3).getString();
entry.model_alias = query.getColumn(4).getString();
entry.model_format = query.getColumn(5).getString();
entry.model_source = query.getColumn(6).getString();
entry.status = StringToStatus(query.getColumn(7).getString());
entry.engine = query.getColumn(8).getString();
entry.metadata = query.getColumn(9).getString();
res.push_back(entry);
}
return ids;
return res;
} catch (const std::exception& e) {
return cpp::fail(e.what());
}
Expand Down
5 changes: 2 additions & 3 deletions engine/database/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ namespace cortex::db {

enum class ModelStatus { Remote, Downloaded, Downloadable };


struct ModelEntry {
std::string model;
std::string author_repo_id;
Expand Down Expand Up @@ -57,9 +56,9 @@ class Models {
cpp::result<std::vector<std::string>, std::string> FindRelatedModel(
const std::string& identifier) const;
bool HasModel(const std::string& identifier) const;
cpp::result<std::vector<std::string>, std::string> GetModelSources() const;
cpp::result<std::vector<std::string>, std::string> GetModels(
cpp::result<std::vector<ModelEntry>, std::string> GetModels(
const std::string& model_src) const;
cpp::result<std::vector<ModelEntry>, std::string> GetModelSources() const;
};

} // namespace cortex::db
12 changes: 6 additions & 6 deletions engine/services/database_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ bool DatabaseService::HasModel(const std::string& identifier) const {
return cortex::db::Models().HasModel(identifier);
}

cpp::result<std::vector<std::string>, std::string>
DatabaseService::GetModelSources() const {
return cortex::db::Models().GetModelSources();
}

cpp::result<std::vector<std::string>, std::string> DatabaseService::GetModels(
cpp::result<std::vector<ModelEntry>, std::string> DatabaseService::GetModels(
const std::string& model_src) const {
return cortex::db::Models().GetModels(model_src);
}

cpp::result<std::vector<ModelEntry>, std::string>
DatabaseService::GetModelSources() const {
return cortex::db::Models().GetModelSources();
}
// end models
5 changes: 3 additions & 2 deletions engine/services/database_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ class DatabaseService {
cpp::result<std::vector<std::string>, std::string> FindRelatedModel(
const std::string& identifier) const;
bool HasModel(const std::string& identifier) const;
cpp::result<std::vector<std::string>, std::string> GetModelSources() const;
cpp::result<std::vector<std::string>, std::string> GetModels(
cpp::result<std::vector<ModelEntry>, std::string> GetModels(
const std::string& model_src) const;
cpp::result<std::vector<ModelEntry>, std::string> GetModelSources()
const;

private:
};
51 changes: 11 additions & 40 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ void ParseGguf(DatabaseService& db_service,
CTL_INF("Adding model to modellist with branch: " << branch);

auto rel = file_manager_utils::ToRelativeCortexDataPath(yaml_name);
CTL_INF("path_to_model_yaml: " << rel.string());
CTL_INF("path_to_model_yaml: " << rel.string()
<< ", model: " << ggufDownloadItem.id);

auto author_id = author.has_value() ? author.value() : "cortexso";
if (!db_service.HasModel(ggufDownloadItem.id)) {
Expand All @@ -86,6 +87,7 @@ void ParseGguf(DatabaseService& db_service,
} else {
if (auto m = db_service.GetModelInfo(ggufDownloadItem.id); m.has_value()) {
auto upd_m = m.value();
upd_m.path_to_model_yaml = rel.string();
upd_m.status = cortex::db::ModelStatus::Downloaded;
if (auto r = db_service.UpdateModelEntry(ggufDownloadItem.id, upd_m);
r.has_error()) {
Expand Down Expand Up @@ -161,6 +163,9 @@ void ModelService::ForceIndexingModelList() {
continue;
}
try {
CTL_DBG(fmu::ToAbsoluteCortexDataPath(
fs::path(model_entry.path_to_model_yaml))
.string());
yaml_handler.ModelConfigFromFile(
fmu::ToAbsoluteCortexDataPath(
fs::path(model_entry.path_to_model_yaml))
Expand All @@ -171,48 +176,12 @@ void ModelService::ForceIndexingModelList() {
} catch (const std::exception& e) {
// remove in db
auto remove_result = db_service_->DeleteModelEntry(model_entry.model);
CTL_DBG(e.what());
// silently ignore result
}
}
}

cpp::result<std::string, std::string> ModelService::DownloadModel(
const std::string& input) {
if (input.empty()) {
return cpp::fail(
"Input must be Cortex Model Hub handle or HuggingFace url!");
}

if (string_utils::StartsWith(input, "https://")) {
return HandleUrl(input);
}

if (input.find(":") != std::string::npos) {
auto parsed = string_utils::SplitBy(input, ":");
if (parsed.size() != 2) {
return cpp::fail("Invalid model handle: " + input);
}
return DownloadModelFromCortexso(parsed[0], parsed[1]);
}

if (input.find("/") != std::string::npos) {
auto parsed = string_utils::SplitBy(input, "/");
if (parsed.size() != 2) {
return cpp::fail("Invalid model handle: " + input);
}

auto author = parsed[0];
auto model_name = parsed[1];
if (author == "cortexso") {
return HandleCortexsoModel(model_name);
}

return DownloadHuggingFaceGgufModel(author, model_name, std::nullopt);
}

return HandleCortexsoModel(input);
}

cpp::result<std::string, std::string> ModelService::HandleCortexsoModel(
const std::string& modelName) {
auto branches =
Expand Down Expand Up @@ -612,7 +581,8 @@ ModelService::DownloadModelFromCortexsoAsync(
.branch_name = branch,
.path_to_model_yaml = rel.string(),
.model_alias = unique_model_id,
.status = cortex::db::ModelStatus::Downloaded};
.status = cortex::db::ModelStatus::Downloaded,
.engine = mc.engine};
auto result = db_service_->AddModelEntry(model_entry);

if (result.has_error()) {
Expand All @@ -621,6 +591,7 @@ ModelService::DownloadModelFromCortexsoAsync(
} else {
if (auto m = db_service_->GetModelInfo(unique_model_id); m.has_value()) {
auto upd_m = m.value();
upd_m.path_to_model_yaml = rel.string();
upd_m.status = cortex::db::ModelStatus::Downloaded;
if (auto r = db_service_->UpdateModelEntry(unique_model_id, upd_m);
r.has_error()) {
Expand Down Expand Up @@ -1157,7 +1128,7 @@ cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(

if (input.find(":") != std::string::npos) {
auto parsed = string_utils::SplitBy(input, ":");
if (parsed.size() != 2) {
if (parsed.size() != 2 && parsed.size() != 3) {
return cpp::fail("Invalid model handle: " + input);
}
return ModelPullInfo{.id = input,
Expand Down
5 changes: 0 additions & 5 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ class ModelService {
inference_svc_(inference_service),
engine_svc_(engine_svc) {};

/**
* Return model id if download successfully
*/
cpp::result<std::string, std::string> DownloadModel(const std::string& input);

cpp::result<std::string, std::string> AbortDownloadModel(
const std::string& task_id);

Expand Down
Loading
Loading