This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit d847779

fix: check model status before inferencing (#1864)
Co-authored-by: vansangpfiev <sang@jan.ai>
1 parent 0746ec9 commit d847779
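
In brief: LoadModel now records the request body of every local (non-remote-engine) load in a new saved_models_ map, and HandleChatCompletion consults GetModelStatus before inferencing, replaying the saved load request when a previously loaded model is no longer running. The change also demotes the inference-body log line from CTL_INF to CTL_DBG.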

2 files changed (+52 -34 lines)


engine/services/inference_service.cc

Lines changed: 49 additions & 33 deletions
@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
   }
   function_calling_utils::PreprocessRequest(json_body);
   auto tool_choice = json_body->get("tool_choice", Json::Value::null);
+  auto model_id = json_body->get("model", "").asString();
+  if (saved_models_.find(model_id) != saved_models_.end()) {
+    // check if model is started, if not start it first
+    Json::Value root;
+    root["model"] = model_id;
+    root["engine"] = engine_type;
+    auto ir = GetModelStatus(std::make_shared<Json::Value>(root));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    if (status != drogon::k200OK) {
+      CTL_INF("Model is not loaded, start loading it: " << model_id);
+      auto res = LoadModel(saved_models_.at(model_id));
+      // ignore return result
+    }
+  }
+
   auto engine_result = engine_service_->GetLoadedEngine(engine_type);
   if (engine_result.has_error()) {
     Json::Value res;
@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     LOG_WARN << "Engine is not loaded yet";
     return cpp::fail(std::make_pair(stt, res));
   }
+
+  if (!model_id.empty()) {
+    if (auto model_service = model_service_.lock()) {
+      auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
+      if (metadata_ptr != nullptr &&
+          !metadata_ptr->tokenizer->chat_template.empty()) {
+        auto tokenizer = metadata_ptr->tokenizer;
+        auto messages = (*json_body)["messages"];
+        Json::Value messages_jsoncpp(Json::arrayValue);
+        for (auto message : messages) {
+          messages_jsoncpp.append(message);
+        }

-  {
-    auto model_id = json_body->get("model", "").asString();
-    if (!model_id.empty()) {
-      if (auto model_service = model_service_.lock()) {
-        auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
-        if (metadata_ptr != nullptr &&
-            !metadata_ptr->tokenizer->chat_template.empty()) {
-          auto tokenizer = metadata_ptr->tokenizer;
-          auto messages = (*json_body)["messages"];
-          Json::Value messages_jsoncpp(Json::arrayValue);
-          for (auto message : messages) {
-            messages_jsoncpp.append(message);
-          }
-
-          Json::Value tools(Json::arrayValue);
-          Json::Value template_data_json;
-          template_data_json["messages"] = messages_jsoncpp;
-          // template_data_json["tools"] = tools;
-
-          auto prompt_result = jinja::RenderTemplate(
-              tokenizer->chat_template, template_data_json,
-              tokenizer->bos_token, tokenizer->eos_token,
-              tokenizer->add_bos_token, tokenizer->add_eos_token,
-              tokenizer->add_generation_prompt);
-          if (prompt_result.has_value()) {
-            (*json_body)["prompt"] = prompt_result.value();
-            Json::Value stops(Json::arrayValue);
-            stops.append(tokenizer->eos_token);
-            (*json_body)["stop"] = stops;
-          } else {
-            CTL_ERR("Failed to render prompt: " + prompt_result.error());
-          }
+        Json::Value tools(Json::arrayValue);
+        Json::Value template_data_json;
+        template_data_json["messages"] = messages_jsoncpp;
+        // template_data_json["tools"] = tools;
+
+        auto prompt_result = jinja::RenderTemplate(
+            tokenizer->chat_template, template_data_json, tokenizer->bos_token,
+            tokenizer->eos_token, tokenizer->add_bos_token,
+            tokenizer->add_eos_token, tokenizer->add_generation_prompt);
+        if (prompt_result.has_value()) {
+          (*json_body)["prompt"] = prompt_result.value();
+          Json::Value stops(Json::arrayValue);
+          stops.append(tokenizer->eos_token);
+          (*json_body)["stop"] = stops;
+        } else {
+          CTL_ERR("Failed to render prompt: " + prompt_result.error());
         }
       }
     }
   }

-  CTL_INF("Json body inference: " + json_body->toStyledString());
+
+  CTL_DBG("Json body inference: " + json_body->toStyledString());

   auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
     if (!tool_choice.isNull()) {
@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
     std::get<RemoteEngineI*>(engine_result.value())
         ->LoadModel(json_body, std::move(cb));
   }
+  if (!engine_service_->IsRemoteEngine(engine_type)) {
+    auto model_id = json_body->get("model", "").asString();
+    saved_models_[model_id] = json_body;
+  }
   return std::make_pair(stt, r);
 }
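
The new pre-inference path reduces to: look the model up in saved_models_; if it was loaded before, query its status; on a non-200 status, replay the saved load request and deliberately ignore the result. Below is a minimal, self-contained C++ sketch of that pattern. The LazyLoader class, its method names, and the std::string stand-in for the Json::Value request body are all hypothetical, invented for illustration; only the control flow mirrors the commit.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for the engine's real request-body type (Json::Value); it
// exists only so this sketch compiles on its own.
using RequestBody = std::shared_ptr<std::string>;

class LazyLoader {
 public:
  // Mirrors the LoadModel change: remember the request body of every
  // local load so the model can be re-loaded later.
  void OnLoadModel(const std::string& model_id, RequestBody body) {
    saved_models_[model_id] = std::move(body);
  }

  // Mirrors the new check at the top of HandleChatCompletion: if the
  // model was loaded before but is no longer running, replay the saved
  // load request and ignore its result.
  void EnsureLoaded(const std::string& model_id) {
    auto it = saved_models_.find(model_id);
    if (it == saved_models_.end())
      return;  // never loaded through us; nothing to replay
    if (!IsRunning(model_id)) {
      std::cout << "Model is not loaded, start loading it: " << model_id
                << '\n';
      LoadModel(it->second);  // return value intentionally ignored
    }
  }

 private:
  // Stubs standing in for GetModelStatus (status_code == 200?) and the
  // real loader; always "not running" so the demo takes the reload path.
  bool IsRunning(const std::string&) const { return false; }
  void LoadModel(const RequestBody&) {}

  std::unordered_map<std::string, RequestBody> saved_models_;
};

int main() {
  LazyLoader loader;
  loader.OnLoadModel("my-model",
                     std::make_shared<std::string>(R"({"model":"my-model"})"));
  loader.EnsureLoaded("my-model");  // prints the reload message
  return 0;
}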

engine/services/inference_service.h

Lines changed: 3 additions & 1 deletion
@@ -47,7 +47,7 @@ class InferenceService {
 
   cpp::result<void, InferResult> HandleRouteRequest(
       std::shared_ptr<SyncQueue> q, std::shared_ptr<Json::Value> json_body);
-
+
   InferResult LoadModel(std::shared_ptr<Json::Value> json_body);
 
   InferResult UnloadModel(const std::string& engine,
@@ -74,4 +74,6 @@ class InferenceService {
  private:
   std::shared_ptr<EngineService> engine_service_;
   std::weak_ptr<ModelService> model_service_;
+  using SavedModel = std::shared_ptr<Json::Value>;
+  std::unordered_map<std::string, SavedModel> saved_models_;
 };
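
A note on the new members: storing the std::shared_ptr<Json::Value> itself keeps the load-request body alive after the original HTTP request has been answered, so it can be replayed verbatim later. A tiny sketch of that ownership pattern, assuming jsoncpp is available as it is in the engine (the "my-model" id is made up):

#include <json/json.h>  // jsoncpp, the JSON library the engine already uses

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

int main() {
  // Same shape as the new members in inference_service.h.
  using SavedModel = std::shared_ptr<Json::Value>;
  std::unordered_map<std::string, SavedModel> saved_models;

  {
    // Pretend this is the body of a load-model request.
    auto json_body = std::make_shared<Json::Value>();
    (*json_body)["model"] = "my-model";
    saved_models[(*json_body)["model"].asString()] = json_body;
  }  // the request is out of scope, but the shared_ptr keeps the body alive

  std::cout << saved_models["my-model"]->toStyledString();
  return 0;
}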
