2 changes: 1 addition & 1 deletion .github/workflows/template-build-linux.yml
@@ -243,7 +243,7 @@ jobs:
with:
upload_url: ${{ inputs.upload_url }}
asset_path: ./engine/cortex.tar.gz
asset_name: cortex-${{ inputs.new_version }}-linux${{ inputs.arch }}.tar.gz
asset_name: cortex-${{ inputs.new_version }}-linux-${{ inputs.arch }}.tar.gz
asset_content_type: application/zip

- name: Upload release assert if public provider is github
4 changes: 2 additions & 2 deletions engine/controllers/engines.cc
@@ -251,9 +251,9 @@ void Engines::InstallRemoteEngine(
.get("get_models_url", "")
.asString();

if (engine.empty() || type.empty() || url.empty()) {
if (engine.empty() || type.empty()) {
Json::Value res;
res["message"] = "Engine name, type, url are required";
res["message"] = "Engine name, type are required";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(res);
resp->setStatusCode(k400BadRequest);
callback(resp);
1 change: 1 addition & 0 deletions engine/controllers/models.cc
@@ -526,6 +526,7 @@ void Models::StartModel(
if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) {
auto model_path = o.asString();
if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) {
mp = model_path;
// Bypass if model does not exist in DB and llama_model_path exists
if (std::filesystem::exists(model_path) &&
!model_service_->HasModel(model_handle)) {
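Note on the models.cc hunk above: it relies on jsoncpp's reference-returning `operator[]`, so assigning to `mp` writes `model_path` back into the request body when that field is missing. A minimal, self-contained sketch of this fallback pattern (standalone program with a hypothetical path, not the controller's actual code):

```cpp
#include <iostream>
#include <json/json.h>

int main() {
  Json::Value body;
  body["llama_model_path"] = "/models/llama.gguf";  // hypothetical path

  // operator[] returns a mutable reference and creates a null member if absent.
  if (auto& o = body["llama_model_path"]; !o.isNull()) {
    if (auto& mp = body["model_path"]; mp.isNull()) {
      mp = o.asString();  // fall back to llama_model_path
    }
  }

  std::cout << body.toStyledString();  // model_path now mirrors llama_model_path
  return 0;
}
```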
28 changes: 23 additions & 5 deletions engine/extensions/remote-engine/remote_engine.cc
@@ -27,14 +27,16 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
auto* context = static_cast<StreamContext*>(userdata);
std::string chunk(ptr, size * nmemb);
CTL_DBG(chunk);
auto check_error = json_helper::ParseJsonString(chunk);
if (check_error.isMember("error")) {
Json::Value check_error;
Json::Reader reader;
if (reader.parse(chunk, check_error)) {
CTL_WRN(chunk);
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = true;
status["status_code"] = k400BadRequest;
context->need_stop = false;
(*context->callback)(std::move(status), std::move(check_error));
return size * nmemb;
}
@@ -58,7 +60,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = 200;
status["status_code"] = k200OK;
context->need_stop = false;
(*context->callback)(std::move(status), Json::Value());
break;
}
@@ -169,6 +172,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(

curl_slist_free_all(headers);
curl_easy_cleanup(curl);
if (context.need_stop) {
CTL_DBG("No stop message received, need to stop");
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
(*context.callback)(std::move(status), Json::Value());
}
return response;
}

@@ -602,6 +614,7 @@ void RemoteEngine::HandleChatCompletion(
status["status_code"] = k500InternalServerError;
Json::Value error;
error["error"] = "Failed to parse response";
LOG_WARN << "Failed to parse response: " << response.body;
callback(std::move(status), std::move(error));
return;
}
@@ -626,15 +639,19 @@

try {
response_json["stream"] = false;
if (!response_json.isMember("model")) {
response_json["model"] = model;
}
response_str = renderer_.Render(template_str, response_json);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
std::string(e.what()));
}
} catch (const std::exception& e) {
// Log error and potentially rethrow or handle accordingly
LOG_WARN << "Error in TransformRequest: " << e.what();
LOG_WARN << "Using original request body";
LOG_WARN << "Error: " << e.what();
LOG_WARN << "Response: " << response.body;
LOG_WARN << "Using original body";
response_str = response_json.toStyledString();
}

@@ -649,6 +666,7 @@ void RemoteEngine::HandleChatCompletion(
Json::Value error;
error["error"] = "Failed to parse response";
callback(std::move(status), std::move(error));
LOG_WARN << "Failed to parse response: " << response_str;
return;
}

1 change: 1 addition & 0 deletions engine/extensions/remote-engine/remote_engine.h
@@ -24,6 +24,7 @@ struct StreamContext {
std::string model;
extensions::TemplateRenderer& renderer;
std::string stream_template;
bool need_stop = true;
};
struct CurlResponse {
std::string body;
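Note on the need_stop changes above (remote_engine.cc and remote_engine.h): the write callback clears the flag whenever it has already delivered a terminal status (an error object or a [DONE] marker), and after the transfer finishes the caller emits a synthetic final status only if the stream ended without one. A stripped-down sketch of that control flow, using a simplified context struct and plain integer status codes instead of the engine's real CURL plumbing and drogon constants:

```cpp
#include <functional>
#include <string>
#include <utility>
#include <json/json.h>

// Simplified stand-in for the engine's StreamContext; the real struct also
// carries the output buffer, model name, renderer, and stream template.
struct MiniStreamContext {
  std::function<void(Json::Value, Json::Value)> callback;
  bool need_stop = true;  // stays true until a terminal chunk has been sent
};

// Stand-in for the CURL write callback: one invocation per received chunk.
void OnChunk(MiniStreamContext& ctx, const std::string& chunk) {
  Json::Value parsed;
  Json::Reader reader;
  if (reader.parse(chunk, parsed)) {
    // Providers return errors as a plain JSON object rather than SSE data.
    Json::Value status;
    status["is_done"] = true;
    status["has_error"] = true;
    status["is_stream"] = true;
    status["status_code"] = 400;
    ctx.need_stop = false;  // terminal status already delivered
    ctx.callback(std::move(status), std::move(parsed));
    return;
  }
  if (chunk.find("[DONE]") != std::string::npos) {
    Json::Value status;
    status["is_done"] = true;
    status["has_error"] = false;
    status["is_stream"] = true;
    status["status_code"] = 200;
    ctx.need_stop = false;
    ctx.callback(std::move(status), Json::Value());
    return;
  }
  // Ordinary data chunks would be forwarded here (omitted).
}

// Called after curl_easy_perform() returns: close the stream if nobody did.
void FinishStream(MiniStreamContext& ctx) {
  if (!ctx.need_stop) return;
  Json::Value status;
  status["is_done"] = true;
  status["has_error"] = false;
  status["is_stream"] = true;
  status["status_code"] = 200;
  ctx.callback(std::move(status), Json::Value());
}
```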
13 changes: 7 additions & 6 deletions engine/extensions/template_renderer.cc
@@ -7,7 +7,9 @@
#include <regex>
#include <stdexcept>
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
namespace extensions {

TemplateRenderer::TemplateRenderer() {
// Configure Inja environment
env_.set_trim_blocks(true);
@@ -21,7 +23,8 @@ TemplateRenderer::TemplateRenderer() {
const auto& value = *args[0];

if (value.is_string()) {
return nlohmann::json(std::string("\"") + value.get<std::string>() +
return nlohmann::json(std::string("\"") +
string_utils::EscapeJson(value.get<std::string>()) +
"\"");
}
return value;
@@ -46,16 +49,14 @@ std::string TemplateRenderer::Render(const std::string& tmpl,
std::string result = env_.render(tmpl, template_data);

// Clean up any potential double quotes in JSON strings
result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
// result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");

LOG_DEBUG << "Result: " << result;

// Validate JSON
auto parsed = nlohmann::json::parse(result);

return result;
} catch (const std::exception& e) {
LOG_ERROR << "Template rendering failed: " << e.what();
LOG_ERROR << "Data: " << data.toStyledString();
LOG_ERROR << "Template: " << tmpl;
throw std::runtime_error(std::string("Template rendering failed: ") +
e.what());
@@ -133,4 +134,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
e.what());
}
}
} // namespace remote_engine
} // namespace extensions
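Note on the template_renderer.cc hunk: the `raw` callback now escapes the string before wrapping it in quotes, so embedded quotes, backslashes, and newlines no longer break the rendered JSON, and the old regex-based cleanup can stay commented out. `string_utils::EscapeJson` itself is not part of this diff; the sketch below shows what such a helper typically does and is an assumption, not the project's actual implementation:

```cpp
#include <cstdio>
#include <string>

// Hypothetical JSON string escaper: quote, backslash, and control characters.
std::string EscapeJson(const std::string& s) {
  std::string out;
  out.reserve(s.size());
  for (unsigned char c : s) {
    switch (c) {
      case '"':  out += "\\\""; break;
      case '\\': out += "\\\\"; break;
      case '\b': out += "\\b"; break;
      case '\f': out += "\\f"; break;
      case '\n': out += "\\n"; break;
      case '\r': out += "\\r"; break;
      case '\t': out += "\\t"; break;
      default:
        if (c < 0x20) {  // remaining control characters -> \u00XX
          char buf[8];
          std::snprintf(buf, sizeof(buf), "\\u%04x", c);
          out += buf;
        } else {
          out += static_cast<char>(c);
        }
    }
  }
  return out;
}
```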
82 changes: 49 additions & 33 deletions engine/services/inference_service.cc
@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
}
function_calling_utils::PreprocessRequest(json_body);
auto tool_choice = json_body->get("tool_choice", Json::Value::null);
auto model_id = json_body->get("model", "").asString();
if (saved_models_.find(model_id) != saved_models_.end()) {
// check if model is started, if not start it first
Json::Value root;
root["model"] = model_id;
root["engine"] = engine_type;
auto ir = GetModelStatus(std::make_shared<Json::Value>(root));
auto status = std::get<0>(ir)["status_code"].asInt();
if (status != drogon::k200OK) {
CTL_INF("Model is not loaded, start loading it: " << model_id);
auto res = LoadModel(saved_models_.at(model_id));
// ignore return result
}
}

auto engine_result = engine_service_->GetLoadedEngine(engine_type);
if (engine_result.has_error()) {
Json::Value res;
@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
LOG_WARN << "Engine is not loaded yet";
return cpp::fail(std::make_pair(stt, res));
}

if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
if (metadata_ptr != nullptr &&
!metadata_ptr->tokenizer->chat_template.empty()) {
auto tokenizer = metadata_ptr->tokenizer;
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

{
auto model_id = json_body->get("model", "").asString();
if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
if (metadata_ptr != nullptr &&
!metadata_ptr->tokenizer->chat_template.empty()) {
auto tokenizer = metadata_ptr->tokenizer;
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
tokenizer->chat_template, template_data_json,
tokenizer->bos_token, tokenizer->eos_token,
tokenizer->add_bos_token, tokenizer->add_eos_token,
tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
Json::Value stops(Json::arrayValue);
stops.append(tokenizer->eos_token);
(*json_body)["stop"] = stops;
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
tokenizer->chat_template, template_data_json, tokenizer->bos_token,
tokenizer->eos_token, tokenizer->add_bos_token,
tokenizer->add_eos_token, tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
Json::Value stops(Json::arrayValue);
stops.append(tokenizer->eos_token);
(*json_body)["stop"] = stops;
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
}
}
}

CTL_INF("Json body inference: " + json_body->toStyledString());

CTL_DBG("Json body inference: " + json_body->toStyledString());

auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
std::get<RemoteEngineI*>(engine_result.value())
->LoadModel(json_body, std::move(cb));
}
if (!engine_service_->IsRemoteEngine(engine_type)) {
auto model_id = json_body->get("model", "").asString();
saved_models_[model_id] = json_body;
}
return std::make_pair(stt, r);
}

4 changes: 3 additions & 1 deletion engine/services/inference_service.h
@@ -47,7 +47,7 @@ class InferenceService {

cpp::result<void, InferResult> HandleRouteRequest(
std::shared_ptr<SyncQueue> q, std::shared_ptr<Json::Value> json_body);

InferResult LoadModel(std::shared_ptr<Json::Value> json_body);

InferResult UnloadModel(const std::string& engine,
@@ -74,4 +74,6 @@ class InferenceService {
private:
std::shared_ptr<EngineService> engine_service_;
std::weak_ptr<ModelService> model_service_;
using SavedModel = std::shared_ptr<Json::Value>;
std::unordered_map<std::string, SavedModel> saved_models_;
};
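Note on the inference_service changes: local-engine load requests are now cached by model id in `saved_models_`, and a later chat completion for a model that is known but not currently loaded replays the cached load request before dispatching. A condensed sketch of that pattern (hypothetical helper class with injected callbacks, not the real `InferenceService`):

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <json/json.h>

// Condensed illustration of the save-and-replay pattern used above.
class ModelAutoLoader {
 public:
  // After a load on a local (non-remote) engine, remember the request body.
  void OnLoadModel(std::shared_ptr<Json::Value> json_body, bool is_remote_engine) {
    if (is_remote_engine) return;
    auto model_id = json_body->get("model", "").asString();
    if (!model_id.empty()) saved_models_[model_id] = std::move(json_body);
  }

  // Before a chat completion: if the model was loaded before but is not
  // loaded now, replay the saved request. The load result is ignored, as in
  // the diff; the completion proceeds either way.
  void EnsureLoaded(const std::string& model_id,
                    const std::function<bool(const std::string&)>& is_loaded,
                    const std::function<void(std::shared_ptr<Json::Value>)>& load) {
    auto it = saved_models_.find(model_id);
    if (it == saved_models_.end() || is_loaded(model_id)) return;
    load(it->second);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Json::Value>> saved_models_;
};
```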