Skip to content

Commit

Permalink
remove validate task
Browse files Browse the repository at this point in the history
update

update

update

update

update
  • Loading branch information
namchuai committed Oct 29, 2024
1 parent 3be991e commit 3edfb66
Show file tree
Hide file tree
Showing 20 changed files with 263 additions and 359 deletions.
5 changes: 1 addition & 4 deletions engine/common/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ class BaseModel {
virtual void GetModels(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
virtual void GetEngines(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
virtual void FineTuning(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) = 0;
Expand All @@ -48,4 +45,4 @@ class BaseEmbedding {
std::function<void(const HttpResponsePtr&)>&& callback) = 0;

// The derived class can also override other methods if needed
};
};
3 changes: 0 additions & 3 deletions engine/common/download_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
#include <filesystem>
#include <sstream>
#include <string>
#include <unordered_map>

enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex };

struct DownloadItem {

std::string id;

std::optional<std::unordered_map<std::string, std::string>> headers;

std::string downloadUrl;

/**
Expand Down
108 changes: 60 additions & 48 deletions engine/controllers/engines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#include "services/engine_service.h"
#include "utils/archive_utils.h"
#include "utils/cortex_utils.h"
#include "utils/engine_constants.h"
#include "utils/logging_utils.h"
#include "utils/string_utils.h"

namespace {
// Need to change this after we rename repositories
// TODO: namh try to remove this
std::string NormalizeEngine(const std::string& engine) {
if (engine == kLlamaEngine) {
return kLlamaRepo;
Expand Down Expand Up @@ -55,65 +55,38 @@ void Engines::InstallEngine(
void Engines::ListEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const {
// TODO: NamH need refactor this
// auto status_list = engine_service_->GetEngineInfoList();

std::vector<std::string> supported_engines{kLlamaEngine, kOnnxEngine,
kTrtLlmEngine};
Json::Value ret;
ret["object"] = "list";
// Json::Value data(Json::arrayValue);
// for (auto& status : status_list) {
// Json::Value ret;
// ret["name"] = status.name;
// ret["description"] = status.description;
// ret["version"] = status.version.value_or("");
// ret["variant"] = status.variant.value_or("");
// ret["productName"] = status.product_name;
// ret["status"] = status.status;
// ret["format"] = status.format;
//
// data.append(std::move(ret));
// }
for (const auto& engine : supported_engines) {
std::cout << engine << std::endl;

// ret["data"] = data;
ret["result"] = "OK";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k200OK);
callback(resp);
}
auto result = engine_service_->GetInstalledEngineVariants(engine);
if (result.has_error()) {
continue;
}
Json::Value variants(Json::arrayValue);
for (const auto& variant : result.value()) {
variants.append(variant.ToJson());
}
ret[engine] = variants;
}

void Engines::GetEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) const {
// auto status = engine_service_->GetEngineInfo(engine);
Json::Value ret;
// if (status.has_value()) {
// ret["name"] = status->name;
// ret["description"] = status->description;
// ret["version"] = status->version.value_or("");
// ret["variant"] = status->variant.value_or("");
// ret["productName"] = status->product_name;
// ret["status"] = status->status;
// ret["format"] = status->format;
//
// auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
// resp->setStatusCode(k200OK);
// callback(resp);
// } else {
ret["message"] = "Engine not found";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(k400BadRequest);
resp->setStatusCode(k200OK);
callback(resp);
// }
}

void Engines::UninstallEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) {
const std::string& engine, const std::string& version,
const std::string& variant) {

auto result = engine_service_->UninstallEngine(engine);
Json::Value ret;
auto result =
engine_service_->UninstallEngineVariant(engine, variant, version);

Json::Value ret;
if (result.has_error()) {
CTL_INF(result.error());
ret["message"] = result.error();
Expand Down Expand Up @@ -309,3 +282,42 @@ void Engines::GetDefaultEngineVariant(
callback(resp);
}
}

void Engines::LoadEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) {
auto result = engine_service_->LoadEngine(engine);
if (result.has_error()) {
Json::Value res;
res["message"] = result.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k400BadRequest);
callback(resp);
} else {
Json::Value res;
res["message"] = "Engine " + engine + " loaded successfully!";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k200OK);
callback(resp);
}
}

void Engines::UnloadEngine(
const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) {
auto result = engine_service_->UnloadEngine(engine);
if (result.has_error()) {
Json::Value res;
res["message"] = result.error();
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k400BadRequest);
callback(resp);
} else {
Json::Value res;
res["message"] = "Engine " + engine + " unloaded successfully!";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k200OK);
callback(resp);
}
}
25 changes: 15 additions & 10 deletions engine/controllers/engines.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ class Engines : public drogon::HttpController<Engines, false> {

// TODO: update this API
METHOD_ADD(Engines::InstallEngine, "/install/{1}", Post);
METHOD_ADD(Engines::UninstallEngine, "/{1}", Delete);
METHOD_ADD(Engines::UninstallEngine, "/{1}/{2}/{3}", Delete);
METHOD_ADD(Engines::ListEngine, "", Get);

// TODO: might better use query param
METHOD_ADD(Engines::GetEngineVersions, "/{1}/versions", Get);
METHOD_ADD(Engines::GetEngineVariants, "/{1}/versions/{2}", Get);
METHOD_ADD(Engines::InstallEngineVariant, "/{1}/versions/{2}/{3}", Post);
Expand All @@ -28,10 +27,11 @@ class Engines : public drogon::HttpController<Engines, false> {
METHOD_ADD(Engines::SetDefaultEngineVariant, "/{1}/default/{2}/{3}", Post);
METHOD_ADD(Engines::GetDefaultEngineVariant, "/{1}/default", Get);

METHOD_ADD(Engines::LoadEngine, "/{1}/load", Post);
METHOD_ADD(Engines::UnloadEngine, "/{1}/load", Delete);

ADD_METHOD_TO(Engines::InstallEngine, "/v1/engines/install/{1}", Post);
ADD_METHOD_TO(Engines::UninstallEngine, "/v1/engines/{1}", Delete);
// TODO: update response of this API
ADD_METHOD_TO(Engines::ListEngine, "/v1/engines", Get);
ADD_METHOD_TO(Engines::UninstallEngine, "/v1/engines/{1}/{2}/{3}", Delete);

METHOD_LIST_END

Expand All @@ -45,13 +45,10 @@ class Engines : public drogon::HttpController<Engines, false> {
void ListEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback) const;

void GetEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) const;

void UninstallEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine);
const std::string& engine, const std::string& version,
const std::string& variant);

void GetEngineVersions(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
Expand Down Expand Up @@ -93,6 +90,14 @@ class Engines : public drogon::HttpController<Engines, false> {
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine) const;

void LoadEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine);

void UnloadEngine(const HttpRequestPtr& req,
std::function<void(const HttpResponsePtr&)>&& callback,
const std::string& engine);

private:
std::shared_ptr<EngineService> engine_service_;
};
3 changes: 2 additions & 1 deletion engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ void RunServer(std::optional<int> port) {

auto event_queue_ptr = std::make_shared<EventQueue>();
cortex::event::EventProcessor event_processor(event_queue_ptr);
auto inference_svc = std::make_shared<services::InferenceService>();

auto download_service = std::make_shared<DownloadService>(event_queue_ptr);
auto engine_service = std::make_shared<EngineService>(download_service);
auto inference_svc =
std::make_shared<services::InferenceService>(engine_service);
auto model_service =
std::make_shared<ModelService>(download_service, inference_svc);

Expand Down
60 changes: 11 additions & 49 deletions engine/services/download_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ostream>
#include <utility>
#include "download_service.h"
#include "utils/curl_utils.h"
#include "utils/format_utils.h"
#include "utils/logging_utils.h"
#include "utils/result.hpp"
Expand All @@ -27,43 +28,9 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {
}
} // namespace

cpp::result<void, std::string> DownloadService::VerifyDownloadTask(
DownloadTask& task) const noexcept {
CLI_LOG("Validating download items, please wait..");

auto total_download_size{0};
std::optional<std::string> err_msg = std::nullopt;

for (auto& item : task.items) {
auto file_size = GetFileSize(item.downloadUrl);
if (file_size.has_error()) {
err_msg = file_size.error();
break;
}

item.bytes = file_size.value();
total_download_size += file_size.value();
}

if (err_msg.has_value()) {
CTL_ERR(err_msg.value());
return cpp::fail(err_msg.value());
}

return {};
}

cpp::result<bool, std::string> DownloadService::AddDownloadTask(
DownloadTask& task,
std::optional<OnDownloadTaskSuccessfully> callback) noexcept {
auto validating_result = VerifyDownloadTask(task);
if (validating_result.has_error()) {
return cpp::fail(validating_result.error());
}

// all items are valid, start downloading
// if any item from the task failed to download, the whole task will be
// considered failed
std::optional<std::string> dl_err_msg = std::nullopt;
bool has_task_done = false;
for (const auto& item : task.items) {
Expand All @@ -87,10 +54,7 @@ cpp::result<bool, std::string> DownloadService::AddDownloadTask(
}

cpp::result<uint64_t, std::string> DownloadService::GetFileSize(
const std::string& url,
const std::optional<
std::reference_wrapper<std::unordered_map<std::string, std::string>>>&
headers) const noexcept {
const std::string& url) const noexcept {

auto curl = curl_easy_init();
if (!curl) {
Expand All @@ -101,10 +65,11 @@ cpp::result<uint64_t, std::string> DownloadService::GetFileSize(
curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());

auto headers = curl_utils::GetHeaders(url);
if (headers.has_value()) {
curl_slist* curl_headers = nullptr;

for (const auto& [key, value] : headers->get()) {
for (const auto& [key, value] : headers.value()) {
auto header = key + ": " + value;
curl_headers = curl_slist_append(curl_headers, header.c_str());
}
Expand Down Expand Up @@ -185,10 +150,11 @@ cpp::result<bool, std::string> DownloadService::Download(
}

curl_easy_setopt(curl, CURLOPT_URL, download_item.downloadUrl.c_str());
if (download_item.headers.has_value()) {
auto headers = curl_utils::GetHeaders(download_item.downloadUrl);
if (headers.has_value()) {
curl_slist* curl_headers = nullptr;

for (const auto& [key, value] : download_item.headers.value()) {
for (const auto& [key, value] : headers.value()) {
auto header = key + ": " + value;
curl_headers = curl_slist_append(curl_headers, header.c_str());
}
Expand Down Expand Up @@ -262,7 +228,7 @@ void DownloadService::ProcessTask(DownloadTask& task) {

active_task_ = std::make_shared<DownloadTask>(task);

for (auto& item : task.items) {
for (const auto& item : task.items) {
auto handle = curl_easy_init();
if (handle == nullptr) {
// skip the task
Expand All @@ -282,10 +248,11 @@ void DownloadService::ProcessTask(DownloadTask& task) {
});
downloading_data_map_.insert(std::make_pair(item.id, dl_data_ptr));

if (item.headers.has_value()) {
auto headers = curl_utils::GetHeaders(item.downloadUrl);
if (headers.has_value()) {
curl_slist* curl_headers = nullptr;

for (const auto& [key, value] : item.headers.value()) {
for (const auto& [key, value] : headers.value()) {
auto header = key + ": " + value;
curl_headers = curl_slist_append(curl_headers, header.c_str());
}
Expand Down Expand Up @@ -408,11 +375,6 @@ void DownloadService::ProcessCompletedTransfers() {

cpp::result<DownloadTask, std::string> DownloadService::AddTask(
DownloadTask& task, std::function<void(const DownloadTask&)> callback) {
auto validate_result = VerifyDownloadTask(task);
if (validate_result.has_error()) {
return cpp::fail(validate_result.error());
}

{
std::lock_guard<std::mutex> lock(callbacks_mutex_);
callbacks_[task.id] = std::move(callback);
Expand Down
8 changes: 1 addition & 7 deletions engine/services/download_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ class DownloadService {
* @param url - url to get file size
*/
cpp::result<uint64_t, std::string> GetFileSize(
const std::string& url,
const std::optional<
std::reference_wrapper<std::unordered_map<std::string, std::string>>>&
headers = std::nullopt) const noexcept;
const std::string& url) const noexcept;

cpp::result<std::string, std::string> StopTask(const std::string& task_id);

Expand All @@ -86,9 +83,6 @@ class DownloadService {
DownloadService* download_service;
};

cpp::result<void, std::string> VerifyDownloadTask(
DownloadTask& task) const noexcept;

cpp::result<bool, std::string> Download(
const std::string& download_id,
const DownloadItem& download_item) noexcept;
Expand Down
Loading

0 comments on commit 3edfb66

Please sign in to comment.