This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 0cd4abe

Merge pull request #271 from janhq/103-feat-enable-llava-feature-in-nitro-1

feat: nitro multi modal

2 parents 4a3f958 + 58fa46c

File tree: 3 files changed (+94, -23 lines)

controllers/llamaCPP.cc

Lines changed: 76 additions & 20 deletions
@@ -8,6 +8,7 @@
 #include <regex>
 #include <string>
 #include <thread>
+#include <trantor/utils/Logger.h>
 
 using namespace inferences;
 using json = nlohmann::json;
@@ -174,6 +175,7 @@ void llamaCPP::chatCompletion(
 
   json data;
   json stopWords;
+  int no_images = 0;
   // To set default value
 
   if (jsonBody) {
@@ -200,29 +202,79 @@ void llamaCPP::chatCompletion(
         (*jsonBody).get("frequency_penalty", 0).asFloat();
     data["presence_penalty"] = (*jsonBody).get("presence_penalty", 0).asFloat();
     const Json::Value &messages = (*jsonBody)["messages"];
-    for (const auto &message : messages) {
-      std::string input_role = message["role"].asString();
-      std::string role;
-      if (input_role == "user") {
-        role = user_prompt;
-        std::string content = message["content"].asString();
-        formatted_output += role + content;
-      } else if (input_role == "assistant") {
-        role = ai_prompt;
-        std::string content = message["content"].asString();
-        formatted_output += role + content;
-      } else if (input_role == "system") {
-        role = system_prompt;
-        std::string content = message["content"].asString();
-        formatted_output = role + content + formatted_output;
 
-      } else {
-        role = input_role;
-        std::string content = message["content"].asString();
-        formatted_output += role + content;
+    if (!llama.multimodal) {
+
+      for (const auto &message : messages) {
+        std::string input_role = message["role"].asString();
+        std::string role;
+        if (input_role == "user") {
+          role = user_prompt;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        } else if (input_role == "assistant") {
+          role = ai_prompt;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        } else if (input_role == "system") {
+          role = system_prompt;
+          std::string content = message["content"].asString();
+          formatted_output = role + content + formatted_output;
+
+        } else {
+          role = input_role;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        }
       }
+      formatted_output += ai_prompt;
+    } else {
+
+      data["image_data"] = json::array();
+      for (const auto &message : messages) {
+        std::string input_role = message["role"].asString();
+        std::string role;
+        if (input_role == "user") {
+          formatted_output += role;
+          for (auto content_piece : message["content"]) {
+            role = user_prompt;
+
+            auto content_piece_type = content_piece["type"].asString();
+            if (content_piece_type == "text") {
+              auto text = content_piece["text"].asString();
+              formatted_output += text;
+            } else if (content_piece_type == "image_url") {
+              auto image_url = content_piece["image_url"]["url"].asString();
+              auto base64_image_data = nitro_utils::extractBase64(image_url);
+              LOG_INFO << base64_image_data;
+              formatted_output += "[img-" + std::to_string(no_images) + "]";
+
+              json content_piece_image_data;
+              content_piece_image_data["data"] = base64_image_data;
+              content_piece_image_data["id"] = no_images;
+              data["image_data"].push_back(content_piece_image_data);
+              no_images++;
+            }
+          }
+
+        } else if (input_role == "assistant") {
+          role = ai_prompt;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        } else if (input_role == "system") {
+          role = system_prompt;
+          std::string content = message["content"].asString();
+          formatted_output = role + content + formatted_output;
+
+        } else {
+          role = input_role;
+          std::string content = message["content"].asString();
+          formatted_output += role + content;
+        }
+      }
+      formatted_output += ai_prompt;
+      LOG_INFO << formatted_output;
     }
-    formatted_output += ai_prompt;
 
     data["prompt"] = formatted_output;
     for (const auto &stop_word : (*jsonBody)["stop"]) {
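
For context, a sketch of the request body shape the new multimodal branch parses. The keys ("messages", "content", "type", "text", "image_url") are the ones read by the code above; the prompt text and base64 data are made-up placeholders:

{
  "messages": [
    {
      "role": "user",
      "content": [
        { "type": "text", "text": "What is in this image?" },
        {
          "type": "image_url",
          "image_url": { "url": "data:image/jpeg;base64,/9j/4AAQ..." }
        }
      ]
    }
  ]
}

For an input like this, the loop builds a prompt of roughly user_prompt + "What is in this image?" + "[img-0]" + ai_prompt and pushes { "data": "<base64 payload>", "id": 0 } into data["image_data"], which appears to follow the [img-N] placeholder convention of the upstream llama.cpp LLaVA server.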
@@ -386,6 +438,10 @@ bool llamaCPP::loadModelImpl(const Json::Value &jsonBody) {
   int drogon_thread = drogon::app().getThreadNum() - 1;
   LOG_INFO << "Drogon thread is:" << drogon_thread;
   if (jsonBody) {
+    if (!jsonBody["mmproj"].isNull()) {
+      LOG_INFO << "MMPROJ FILE detected, multi-model enabled!";
+      params.mmproj = jsonBody["mmproj"].asString();
+    }
     params.model = jsonBody["llama_model_path"].asString();
     params.n_gpu_layers = jsonBody.get("ngl", 100).asInt();
     params.n_ctx = jsonBody.get("ctx_len", 2048).asInt();
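
Likewise, a sketch of a load-model request that would take the new mmproj path; the keys are the ones read in loadModelImpl above, while the file paths are hypothetical:

{
  "llama_model_path": "/models/llava-v1.5-7b.Q4_K_M.gguf",
  "mmproj": "/models/mmproj-llava-v1.5-7b-f16.gguf",
  "ngl": 100,
  "ctx_len": 2048
}

When "mmproj" is present and non-null, params.mmproj is set before the model loads, which presumably is what flips llama.multimodal and routes chat completions through the multimodal branch above.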

controllers/llamaCPP.h

Lines changed: 4 additions & 3 deletions
@@ -1834,7 +1834,7 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
 public:
   llamaCPP() {
     // Some default values for now below
-    log_disable(); // Disable the log to file feature, reduce bloat for
+    // log_disable(); // Disable the log to file feature, reduce bloat for
     // target
     // system ()
     std::vector<std::string> llama_models =
@@ -1877,8 +1877,9 @@ class llamaCPP : public drogon::HttpController<llamaCPP> {
   METHOD_LIST_END
   void chatCompletion(const HttpRequestPtr &req,
                       std::function<void(const HttpResponsePtr &)> &&callback);
-  void chatCompletionPrelight(const HttpRequestPtr &req,
-                              std::function<void(const HttpResponsePtr &)> &&callback);
+  void chatCompletionPrelight(
+      const HttpRequestPtr &req,
+      std::function<void(const HttpResponsePtr &)> &&callback);
   void embedding(const HttpRequestPtr &req,
                  std::function<void(const HttpResponsePtr &)> &&callback);
   void loadModel(const HttpRequestPtr &req,

utils/nitro_utils.h

Lines changed: 14 additions & 0 deletions
@@ -6,6 +6,7 @@
 #include <drogon/HttpResponse.h>
 #include <iostream>
 #include <ostream>
+#include <regex>
 // Include platform-specific headers
 #ifdef _WIN32
 #include <winsock2.h>
@@ -18,6 +19,19 @@ namespace nitro_utils {
 
 inline std::string models_folder = "./models";
 
+inline std::string extractBase64(const std::string &input) {
+  std::regex pattern("base64,(.*)");
+  std::smatch match;
+
+  if (std::regex_search(input, match, pattern)) {
+    std::string base64_data = match[1];
+    base64_data = base64_data.substr(0, base64_data.length() - 1);
+    return base64_data;
+  }
+
+  return "";
+}
+
 inline std::vector<std::string> listFilesInDir(const std::string &path) {
   std::vector<std::string> files;
 
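A minimal usage sketch for the new helper (the data URL is made up). Note that, as written, the substr call trims the final character of the captured group, so callers get everything after "base64," minus its last character:

#include <iostream>
#include "utils/nitro_utils.h" // for nitro_utils::extractBase64

int main() {
  // A hypothetical image_url value in data-URL form.
  std::string url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==";
  // The regex captures everything after "base64," and the substr then
  // drops the last character, so this prints: iVBORw0KGgoAAAANSUhEUg=
  std::cout << nitro_utils::extractBase64(url) << std::endl;
  return 0;
}
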
0 commit comments