 
 namespace {
 
+const std::unordered_set<std::string> kIgnoredParams = {
+    "model",        "model_alias",     "embedding",  "ai_prompt",
+    "ai_template",  "prompt_template", "mmproj",     "system_prompt",
+    "created",      "stream",          "name",       "os",
+    "owned_by",     "files",           "gpu_arch",   "quantization_method",
+    "engine",       "system_template", "max_tokens", "user_template",
+    "user_prompt",  "min_keep",        "mirostat",   "mirostat_eta",
+    "mirostat_tau", "text_model",      "version",    "n_probs",
+    "object",       "penalize_nl",     "precision",  "size",
+    "stop",         "tfs_z",           "typ_p"};
+
+const std::unordered_map<std::string, std::string> kParamsMap = {
+    {"cpu_threads", "--threads"},
+    {"n_ubatch", "--ubatch-size"},
+    {"n_batch", "--batch-size"},
+    {"n_parallel", "--parallel"},
+    {"temperature", "--temp"},
+    {"top_k", "--top-k"},
+    {"top_p", "--top-p"},
+    {"min_p", "--min-p"},
+    {"dynatemp_exponent", "--dynatemp-exp"},
+    {"ctx_len", "--ctx-size"},
+    {"ngl", "-ngl"},
+};
+
 constexpr const int k200OK = 200;
 constexpr const int k400BadRequest = 400;
 constexpr const int k409Conflict = 409;
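
These two tables drive the rewritten ConvertJsonToParams/ConvertJsonToParamsVector further down: members in kIgnoredParams are dropped, members in kParamsMap are renamed to their llama-server flag, and everything else passes through. A minimal sketch of that lookup, assuming the tables above (the helper name and the pass-through fallback are illustrative, not part of this commit):

    // Sketch: translate one JSON member into argv-style tokens.
    std::vector<std::string> TranslateMember(const std::string& member,
                                             const Json::Value& value) {
      if (kIgnoredParams.find(member) != kIgnoredParams.end()) {
        return {};  // e.g. "stream" or "owned_by" emit nothing
      }
      if (auto it = kParamsMap.find(member); it != kParamsMap.end()) {
        // e.g. {"ctx_len": 4096} becomes {"--ctx-size", "4096"}
        return {it->second, value.asString()};
      }
      return {"--" + member, value.asString()};  // assumed pass-through
    }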
@@ -335,9 +360,9 @@ Json::Value ParseJsonString(const std::string& json_str) {
 void LlamaEngine::Load(EngineLoadOption opts) {
   load_opt_ = opts;
   LOG_DEBUG << "Loading engine..";
-
   LOG_DEBUG << "Is custom engine path: " << opts.is_custom_engine_path;
   LOG_DEBUG << "Engine path: " << opts.engine_path.string();
+  LOG_DEBUG << "Log path: " << opts.log_path.string();
 
   SetFileLogger(opts.max_log_lines, opts.log_path.string());
   SetLogLevel(opts.log_level);
@@ -351,6 +376,9 @@ void LlamaEngine::Unload(EngineUnloadOption opts) {
 
 LlamaEngine::LlamaEngine(int log_option) {
   trantor::Logger::setLogLevel(trantor::Logger::kInfo);
+  if (log_option == kFileLoggerOption) {
+    async_file_logger_ = std::make_unique<trantor::FileLogger>();
+  }
 
   common_log_pause(common_log_main());
 
@@ -377,6 +405,7 @@ LlamaEngine::~LlamaEngine() {
     l.ReleaseResources();
   }
   server_map_.clear();
+  async_file_logger_.reset();
 
   LOG_INFO << "LlamaEngine destructed successfully";
 }
@@ -513,6 +542,15 @@ void LlamaEngine::GetModelStatus(std::shared_ptr<Json::Value> json_body,
 
   auto model_id = llama_utils::GetModelId(*json_body);
   if (auto is_loaded = CheckModelLoaded(callback, model_id); is_loaded) {
+    if (IsLlamaServerModel(model_id)) {
+      Json::Value json_resp;
+      json_resp["model_loaded"] = is_loaded;
+      callback(ResStatus(IsDone{true}, HasError{false}, IsStream{false},
+                         StatusCode{k200OK})
+                   .ToJson(),
+               std::move(json_resp));
+      return;
+    }
     // CheckModelLoaded guarantees that model_id exists in server_ctx_map;
     auto si = server_map_.find(model_id);
     Json::Value json_resp;
@@ -567,17 +605,21 @@ void LlamaEngine::StopInferencing(const std::string& model_id) {
 
 void LlamaEngine::SetFileLogger(int max_log_lines,
                                 const std::string& log_path) {
+  if (!async_file_logger_) {
+    async_file_logger_ = std::make_unique<trantor::FileLogger>();
+  }
+
+  async_file_logger_->setFileName(log_path);
+  async_file_logger_->setMaxLines(max_log_lines);  // keep only the newest max_log_lines lines
+  async_file_logger_->startLogging();
   trantor::Logger::setOutputFunction(
       [&](const char* msg, const uint64_t len) {
-        if (load_opt_.logger) {
-          if (auto l = static_cast<trantor::FileLogger*>(load_opt_.logger); l) {
-            l->output_(msg, len);
-          }
-        }
+        if (async_file_logger_)
+          async_file_logger_->output_(msg, len);
       },
       [&]() {
-        if (load_opt_.logger)
-          load_opt_.logger->flush();
+        if (async_file_logger_)
+          async_file_logger_->flush();
       });
   llama_log_set(
       [](ggml_log_level level, const char* text, void* user_data) {
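
The rewrite above makes the engine own its log sink (async_file_logger_) instead of casting the borrowed load_opt_.logger pointer, which also lines up with the async_file_logger_.reset() added to the destructor. The same wiring in isolation, a minimal sketch using only calls that appear in this diff (the global sink and function name are assumptions; setMaxLines comes from the project's FileLogger, whose header is project-specific and omitted here):

    #include <trantor/utils/Logger.h>
    #include <memory>
    #include <string>

    std::unique_ptr<trantor::FileLogger> g_sink;  // engine-owned sink

    void WireFileSink(const std::string& path, int max_lines) {
      if (!g_sink) g_sink = std::make_unique<trantor::FileLogger>();
      g_sink->setFileName(path);
      g_sink->setMaxLines(max_lines);  // keep only the newest max_lines lines
      g_sink->startLogging();
      trantor::Logger::setOutputFunction(
          [](const char* msg, const uint64_t len) {
            if (g_sink) g_sink->output_(msg, len);  // write callback
          },
          []() {
            if (g_sink) g_sink->flush();  // flush callback
          });
    }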
@@ -607,7 +649,7 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
   }
 
   // Spawn llama.cpp server only if it is a chat model
-  if (!json_body->isMember("mmproj")) {
+  if (!json_body->isMember("mmproj") || (*json_body)["mmproj"].isNull()) {
     return SpawnLlamaServer(*json_body);
   }
   common_params params;
@@ -698,21 +740,21 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
     params.cache_type_k = kv_cache_type_from_str(cache_type_k);
   }
   params.cache_type_v = params.cache_type_k;
-  LOG_DEBUG << "cache_type: " << params.cache_type_k;
+  LOG_INFO << "cache_type: " << params.cache_type_k;
 
   auto fa = json_body->get("flash_attn", true).asBool();
   auto force_enable_fa = params.cache_type_k != GGML_TYPE_F16;
   if (force_enable_fa) {
-    LOG_DEBUG << "Using KV cache quantization, force enable Flash Attention";
+    LOG_INFO << "Using KV cache quantization, force enable Flash Attention";
   }
   params.flash_attn = fa || force_enable_fa;
   if (params.flash_attn) {
-    LOG_DEBUG << "Enabled Flash Attention";
+    LOG_INFO << "Enabled Flash Attention";
   }
 
   params.use_mmap = json_body->get("use_mmap", true).asBool();
   if (!params.use_mmap) {
-    LOG_DEBUG << "Disabled mmap";
+    LOG_INFO << "Disabled mmap";
   }
   params.n_predict = json_body->get("n_predict", -1).asInt();
   params.prompt = json_body->get("prompt", "").asString();
@@ -732,7 +774,7 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
   server_map_[model_id].repeat_last_n =
       json_body->get("repeat_last_n", 32).asInt();
   server_map_[model_id].stop_words = (*json_body)["stop"];
-  LOG_DEBUG << "stop: " << server_map_[model_id].stop_words.toStyledString();
+  LOG_INFO << "stop: " << server_map_[model_id].stop_words.toStyledString();
 
   if (!json_body->operator[]("llama_log_folder").isNull()) {
     common_log_resume(common_log_main());
@@ -1337,7 +1379,7 @@ bool LlamaEngine::HasForceStopInferenceModel(const std::string& id) const {
 
 bool LlamaEngine::SpawnLlamaServer(const Json::Value& json_params) {
   auto wait_for_server_up = [](const std::string& host, int port) {
-    for (size_t i = 0; i < 120; i++) {
+    for (size_t i = 0; i < 10; i++) {
       httplib::Client cli(host + ":" + std::to_string(port));
       auto res = cli.Get("/health");
       if (res && res->status == httplib::StatusCode::OK_200) {
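
Cutting the attempt budget from 120 to 10 shortens how long model load blocks waiting for the spawned llama-server to pass its health check. The same polling pattern as a standalone sketch (the 500 ms pause is an assumption; the actual delay between attempts sits outside this hunk):

    #include <chrono>
    #include <string>
    #include <thread>
    #include <httplib.h>

    // Poll /health until the child answers 200 OK or the budget runs out;
    // 10 tries spaced 500 ms apart gives the server roughly five seconds.
    bool WaitForServerUp(const std::string& host, int port) {
      for (size_t i = 0; i < 10; i++) {
        httplib::Client cli(host + ":" + std::to_string(port));
        if (auto res = cli.Get("/health");
            res && res->status == httplib::StatusCode::OK_200) {
          return true;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(500));
      }
      return false;
    }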
@@ -1385,7 +1427,7 @@ bool LlamaEngine::SpawnLlamaServer(const Json::Value& json_params) {
     std::string exe_w = "llama-server.exe";
     std::string wcmds =
         load_opt_.engine_path.string() + "/" + exe_w + " " + params;
-    LOG_DEBUG << "wcmds: " << wcmds;
+    LOG_INFO << "wcmds: " << wcmds;
     std::vector<wchar_t> mutable_cmds(wcmds.begin(), wcmds.end());
     mutable_cmds.push_back(L'\0');
     // Create child process
@@ -1468,19 +1510,16 @@ std::string LlamaEngine::ConvertJsonToParams(const Json::Value& root) {
 
   for (const auto& member : root.getMemberNames()) {
     if (member == "model_path" || member == "llama_model_path") {
-      ss << "--model" << " ";
-      ss << "\"" << root[member].asString() << "\"";
-      continue;
-    } else if (member == "model" || member == "model_alias" ||
-               member == "embedding") {
+      if (!root[member].isNull()) {
+        ss << "--model" << " ";
+        ss << "\"" << root[member].asString() << "\"";
+      }
       continue;
-    } else if (member == "ctx_len") {
-      ss << "--ctx-size" << " ";
-      ss << "\"" << std::to_string(root[member].asInt()) << "\"";
+    } else if (kIgnoredParams.find(member) != kIgnoredParams.end()) {
       continue;
-    } else if (member == "ngl") {
-      ss << "-ngl" << " ";
-      ss << "\"" << std::to_string(root[member].asInt()) << "\"";
+    } else if (kParamsMap.find(member) != kParamsMap.end()) {
+      ss << kParamsMap.at(member) << " ";
+      ss << root[member].asString() << " ";
       continue;
     } else if (member == "model_type") {
       if (root[member].asString() == "embedding") {
@@ -1494,6 +1533,8 @@ std::string LlamaEngine::ConvertJsonToParams(const Json::Value& root) {
       ss << "\"" << root[member].asString() << "\"";
     } else if (root[member].isInt()) {
       ss << root[member].asInt() << " ";
+    } else if (root[member].isDouble()) {
+      ss << root[member].asDouble() << " ";
     } else if (root[member].isArray()) {
       ss << "[";
       bool first = true;
@@ -1521,16 +1562,11 @@ std::vector<std::string> LlamaEngine::ConvertJsonToParamsVector(
       res.push_back("--model");
       res.push_back(root[member].asString());
       continue;
-    } else if (member == "model" || member == "model_alias" ||
-               member == "embedding") {
-      continue;
-    } else if (member == "ctx_len") {
-      res.push_back("--ctx-size");
-      res.push_back(std::to_string(root[member].asInt()));
+    } else if (kIgnoredParams.find(member) != kIgnoredParams.end()) {
       continue;
-    } else if (member == "ngl") {
-      res.push_back("-ngl");
-      res.push_back(std::to_string(root[member].asInt()));
+    } else if (kParamsMap.find(member) != kParamsMap.end()) {
+      res.push_back(kParamsMap.at(member));
+      res.push_back(root[member].asString());
       continue;
     } else if (member == "model_type") {
       if (root[member].asString() == "embedding") {
@@ -1544,6 +1580,8 @@ std::vector<std::string> LlamaEngine::ConvertJsonToParamsVector(
       res.push_back(root[member].asString());
     } else if (root[member].isInt()) {
       res.push_back(std::to_string(root[member].asInt()));
+    } else if (root[member].isDouble()) {
+      res.push_back(std::to_string(root[member].asDouble()));
     } else if (root[member].isArray()) {
       std::stringstream ss;
       ss << "[";
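
With the tables at the top of the file, the vector variant now turns a typical load request into argv-style tokens: an input like {"llama_model_path": "/models/q4.gguf", "ctx_len": 4096, "ngl": 33, "stream": true} (values invented for illustration) yields {"--model", "/models/q4.gguf", "--ctx-size", "4096", "-ngl", "33"}, with "stream" swallowed by kIgnoredParams. One subtlety in the two new double branches: std::to_string always prints six decimals, while the stringstream path in ConvertJsonToParams uses default ostream precision, so the two functions render the same value differently:

    #include <iostream>
    #include <sstream>
    #include <string>

    int main() {
      double temp = 0.7;
      std::stringstream ss;
      ss << temp;                                 // stringstream path
      std::cout << ss.str() << "\n";              // prints 0.7
      std::cout << std::to_string(temp) << "\n";  // vector path, prints 0.700000
      return 0;
    }

llama-server parses both spellings identically, so the difference only matters when the generated command line is compared textually, e.g. in logs or tests.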