@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
1414 }
1515 function_calling_utils::PreprocessRequest (json_body);
1616 auto tool_choice = json_body->get (" tool_choice" , Json::Value::null);
17+ auto model_id = json_body->get (" model" , " " ).asString ();
18+ if (saved_models_.find (model_id) != saved_models_.end ()) {
19+ // check if model is started, if not start it first
20+ Json::Value root;
21+ root[" model" ] = model_id;
22+ root[" engine" ] = engine_type;
23+ auto ir = GetModelStatus (std::make_shared<Json::Value>(root));
24+ auto status = std::get<0 >(ir)[" status_code" ].asInt ();
25+ if (status != drogon::k200OK) {
26+ CTL_INF (" Model is not loaded, start loading it: " << model_id);
27+ auto res = LoadModel (saved_models_.at (model_id));
28+ // ignore return result
29+ }
30+ }
31+
1732 auto engine_result = engine_service_->GetLoadedEngine (engine_type);
1833 if (engine_result.has_error ()) {
1934 Json::Value res;
@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
2338 LOG_WARN << " Engine is not loaded yet" ;
2439 return cpp::fail (std::make_pair (stt, res));
2540 }
41+
42+ if (!model_id.empty ()) {
43+ if (auto model_service = model_service_.lock ()) {
44+ auto metadata_ptr = model_service->GetCachedModelMetadata (model_id);
45+ if (metadata_ptr != nullptr &&
46+ !metadata_ptr->tokenizer ->chat_template .empty ()) {
47+ auto tokenizer = metadata_ptr->tokenizer ;
48+ auto messages = (*json_body)[" messages" ];
49+ Json::Value messages_jsoncpp (Json::arrayValue);
50+ for (auto message : messages) {
51+ messages_jsoncpp.append (message);
52+ }
2653
27- {
28- auto model_id = json_body->get (" model" , " " ).asString ();
29- if (!model_id.empty ()) {
30- if (auto model_service = model_service_.lock ()) {
31- auto metadata_ptr = model_service->GetCachedModelMetadata (model_id);
32- if (metadata_ptr != nullptr &&
33- !metadata_ptr->tokenizer ->chat_template .empty ()) {
34- auto tokenizer = metadata_ptr->tokenizer ;
35- auto messages = (*json_body)[" messages" ];
36- Json::Value messages_jsoncpp (Json::arrayValue);
37- for (auto message : messages) {
38- messages_jsoncpp.append (message);
39- }
40-
41- Json::Value tools (Json::arrayValue);
42- Json::Value template_data_json;
43- template_data_json[" messages" ] = messages_jsoncpp;
44- // template_data_json["tools"] = tools;
45-
46- auto prompt_result = jinja::RenderTemplate (
47- tokenizer->chat_template , template_data_json,
48- tokenizer->bos_token , tokenizer->eos_token ,
49- tokenizer->add_bos_token , tokenizer->add_eos_token ,
50- tokenizer->add_generation_prompt );
51- if (prompt_result.has_value ()) {
52- (*json_body)[" prompt" ] = prompt_result.value ();
53- Json::Value stops (Json::arrayValue);
54- stops.append (tokenizer->eos_token );
55- (*json_body)[" stop" ] = stops;
56- } else {
57- CTL_ERR (" Failed to render prompt: " + prompt_result.error ());
58- }
54+ Json::Value tools (Json::arrayValue);
55+ Json::Value template_data_json;
56+ template_data_json[" messages" ] = messages_jsoncpp;
57+ // template_data_json["tools"] = tools;
58+
59+ auto prompt_result = jinja::RenderTemplate (
60+ tokenizer->chat_template , template_data_json, tokenizer->bos_token ,
61+ tokenizer->eos_token , tokenizer->add_bos_token ,
62+ tokenizer->add_eos_token , tokenizer->add_generation_prompt );
63+ if (prompt_result.has_value ()) {
64+ (*json_body)[" prompt" ] = prompt_result.value ();
65+ Json::Value stops (Json::arrayValue);
66+ stops.append (tokenizer->eos_token );
67+ (*json_body)[" stop" ] = stops;
68+ } else {
69+ CTL_ERR (" Failed to render prompt: " + prompt_result.error ());
5970 }
6071 }
6172 }
6273 }
6374
64- CTL_INF (" Json body inference: " + json_body->toStyledString ());
75+
76+ CTL_DBG (" Json body inference: " + json_body->toStyledString ());
6577
6678 auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
6779 if (!tool_choice.isNull ()) {
@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
205217 std::get<RemoteEngineI*>(engine_result.value ())
206218 ->LoadModel (json_body, std::move (cb));
207219 }
220+ if (!engine_service_->IsRemoteEngine (engine_type)) {
221+ auto model_id = json_body->get (" model" , " " ).asString ();
222+ saved_models_[model_id] = json_body;
223+ }
208224 return std::make_pair (stt, r);
209225}
210226
0 commit comments